Skip to content

Add error checks for accessing Request members #184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pkgs/dart_mcp/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 0.2.3-wip

- Added error checking to required fields of all `Request` subclasses so that
they will throw helpful errors when accessed and not set.

## 0.2.2

- Refactor `ClientImplementation` and `ServerImplementation` to the shared
Expand Down
17 changes: 14 additions & 3 deletions pkgs/dart_mcp/lib/src/api/completions.dart
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,22 @@ extension type CompleteRequest.fromMap(Map<String, Object?> _value)
///
/// In the case of a [ResourceReference], it must refer to a
/// [ResourceTemplate].
Reference get ref => _value['ref'] as Reference;
Reference get ref {
final ref = _value['ref'] as Reference?;
if (ref == null) {
throw ArgumentError('Missing ref field in $CompleteRequest.');
}
return ref;
}

/// The argument's information.
CompletionArgument get argument =>
(_value['argument'] as Map).cast<String, Object?>() as CompletionArgument;
CompletionArgument get argument {
final argument = _value['argument'] as CompletionArgument?;
if (argument == null) {
throw ArgumentError('Missing argument field in $CompleteRequest.');
}
return argument;
}
}

/// The server's response to a completion/complete request
Expand Down
17 changes: 14 additions & 3 deletions pkgs/dart_mcp/lib/src/api/initialization.dart
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,21 @@ extension type InitializeRequest._fromMap(Map<String, Object?> _value)
ProtocolVersion? get protocolVersion =>
ProtocolVersion.tryParse(_value['protocolVersion'] as String);

ClientCapabilities get capabilities =>
_value['capabilities'] as ClientCapabilities;
ClientCapabilities get capabilities {
final capabilities = _value['capabilities'] as ClientCapabilities?;
if (capabilities == null) {
throw ArgumentError('Missing capabilities field in $InitializeRequest.');
}
return capabilities;
}

Implementation get clientInfo => _value['clientInfo'] as Implementation;
Implementation get clientInfo {
final clientInfo = _value['clientInfo'] as Implementation?;
if (clientInfo == null) {
throw ArgumentError('Missing clientInfo field in $InitializeRequest.');
}
return clientInfo;
}
}

/// After receiving an initialize request from the client, the server sends
Expand Down
14 changes: 12 additions & 2 deletions pkgs/dart_mcp/lib/src/api/logging.dart
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,18 @@ extension type SetLevelRequest.fromMap(Map<String, Object?> _value)
///
/// The server should send all logs at this level and higher (i.e., more
/// severe) to the client as notifications/message.
LoggingLevel get level =>
LoggingLevel.values.firstWhere((level) => level.name == _value['level']);
LoggingLevel get level {
final levelName = _value['level'];
final foundLevel = LoggingLevel.values.firstWhereOrNull(
(level) => level.name == levelName,
);
if (foundLevel == null) {
throw ArgumentError(
"Invalid level field in $SetLevelRequest: didn't find level $levelName",
);
}
return foundLevel;
}
}

/// Notification of a log message passed from server to client.
Expand Down
8 changes: 7 additions & 1 deletion pkgs/dart_mcp/lib/src/api/prompts.dart
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ extension type GetPromptRequest.fromMap(Map<String, Object?> _value)
});

/// The name of the prompt or prompt template.
String get name => _value['name'] as String;
String get name {
final name = _value['name'] as String?;
if (name == null) {
throw ArgumentError('Missing name field in $GetPromptRequest.');
}
return name;
}

/// Arguments to use for templating the prompt.
Map<String, Object?>? get arguments =>
Expand Down
24 changes: 21 additions & 3 deletions pkgs/dart_mcp/lib/src/api/resources.dart
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,13 @@ extension type ReadResourceRequest.fromMap(Map<String, Object?> _value)

/// The URI of the resource to read. The URI can use any protocol; it is
/// up to the server how to interpret it.
String get uri => _value['uri'] as String;
String get uri {
final uri = _value['uri'] as String?;
if (uri == null) {
throw ArgumentError('Missing uri field in $ReadResourceRequest.');
}
return uri;
}
}

/// The server's response to a resources/read request from the client.
Expand Down Expand Up @@ -128,7 +134,13 @@ extension type SubscribeRequest.fromMap(Map<String, Object?> _value)

/// The URI of the resource to subscribe to. The URI can use any protocol;
/// it is up to the server how to interpret it.
String get uri => _value['uri'] as String;
String get uri {
final uri = _value['uri'] as String?;
if (uri == null) {
throw ArgumentError('Missing uri field in $SubscribeRequest.');
}
return uri;
}
}

/// Sent from the client to request cancellation of resources/updated
Expand All @@ -146,7 +158,13 @@ extension type UnsubscribeRequest.fromMap(Map<String, Object?> _value)
UnsubscribeRequest.fromMap({'uri': uri, if (meta != null) '_meta': meta});

/// The URI of the resource to unsubscribe from.
String get uri => _value['uri'] as String;
String get uri {
final uri = _value['uri'] as String?;
if (uri == null) {
throw ArgumentError('Missing uri field in $UnsubscribeRequest.');
}
return uri;
}
}

/// A notification from the server to the client, informing it that a resource
Expand Down
8 changes: 7 additions & 1 deletion pkgs/dart_mcp/lib/src/api/roots.dart
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ extension type ListRootsResult.fromMap(Map<String, Object?> _value)
if (meta != null) '_meta': meta,
});

List<Root> get roots => (_value['roots'] as List).cast<Root>();
List<Root> get roots {
final roots = _value['roots'] as List?;
if (roots == null) {
throw ArgumentError('Missing roots field in $ListRootsResult.');
}
return roots.cast<Root>();
}
}

/// Represents a root directory or file that the server can operate on.
Expand Down
17 changes: 14 additions & 3 deletions pkgs/dart_mcp/lib/src/api/sampling.dart
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,13 @@ extension type CreateMessageRequest.fromMap(Map<String, Object?> _value)
});

/// The messages to send to the LLM.
List<SamplingMessage> get messages =>
(_value['messages'] as List).cast<SamplingMessage>();
List<SamplingMessage> get messages {
final messages = _value['messages'] as List?;
if (messages == null) {
throw ArgumentError('Missing messages field in $CreateMessageRequest.');
}
return messages.cast<SamplingMessage>();
}

/// The server's preferences for which model to select.
///
Expand Down Expand Up @@ -69,7 +74,13 @@ extension type CreateMessageRequest.fromMap(Map<String, Object?> _value)
/// The maximum number of tokens to sample, as requested by the server.
///
/// The client MAY choose to sample fewer tokens than requested.
int get maxTokens => _value['maxTokens'] as int;
int get maxTokens {
final maxTokens = _value['maxTokens'] as int?;
if (maxTokens == null) {
throw ArgumentError('Missing maxTokens field in $CreateMessageRequest.');
}
return maxTokens;
}

/// Note: This has no documentation in the specification or schema.
List<String>? get stopSequences =>
Expand Down
8 changes: 7 additions & 1 deletion pkgs/dart_mcp/lib/src/api/tools.dart
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,13 @@ extension type CallToolRequest._fromMap(Map<String, Object?> _value)
});

/// The name of the method to invoke.
String get name => _value['name'] as String;
String get name {
final name = _value['name'] as String?;
if (name == null) {
throw ArgumentError('Missing name field in $CallToolRequest');
}
return name;
}

/// The arguments to pass to the method.
Map<String, Object?>? get arguments =>
Expand Down
13 changes: 9 additions & 4 deletions pkgs/dart_mcp/lib/src/shared.dart
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,15 @@ base class MCPBase {
void registerRequestHandler<T extends Request?, R extends Result?>(
String name,
FutureOr<R> Function(T) impl,
) => _peer.registerMethod(
name,
(Parameters p) => impl((p.value as Map?)?.cast<String, Object?>() as T),
);
) => _peer.registerMethod(name, (Parameters p) {
if (p.value != null && p.value is! Map) {
throw ArgumentError(
'Request to $name must be a Map or null. Instead, got '
'${p.value.runtimeType}',
);
}
return impl((p.value as Map?)?.cast<String, Object?>() as T);
});

/// Registers a notification handler named [name] on this server.
void registerNotificationHandler<T extends Notification?>(
Expand Down
4 changes: 2 additions & 2 deletions pkgs/dart_mcp/pubspec.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: dart_mcp
version: 0.2.2
version: 0.2.3-wip
description: A package for making MCP servers and clients.
repository: https://github.com/dart-lang/ai/tree/main/pkgs/dart_mcp
issue_tracker: https://github.com/dart-lang/ai/issues?q=is%3Aissue+is%3Aopen+label%3Apackage%3Adart_mcp
Expand All @@ -10,7 +10,7 @@ environment:
dependencies:
async: ^2.13.0
collection: ^1.19.1
json_rpc_2: '>=3.0.3 <5.0.0'
json_rpc_2: ">=3.0.3 <5.0.0"
meta: ^1.16.0
stream_channel: ^2.1.4
stream_transform: ^2.1.1
Expand Down
47 changes: 47 additions & 0 deletions pkgs/dart_mcp/test/api/api_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,51 @@ void main() {
false,
);
});

group('API object validation', () {
test('throws when required fields are missing', () {
final empty = <String, Object?>{};

// Initialization
expect(
() => (empty as InitializeRequest).capabilities,
throwsArgumentError,
);
expect(
() => (empty as InitializeRequest).clientInfo,
throwsArgumentError,
);

// Tools
expect(() => (empty as CallToolRequest).name, throwsArgumentError);

// Resources
expect(() => (empty as ReadResourceRequest).uri, throwsArgumentError);
expect(() => (empty as SubscribeRequest).uri, throwsArgumentError);
expect(() => (empty as UnsubscribeRequest).uri, throwsArgumentError);

// Roots
expect(() => (empty as ListRootsResult).roots, throwsArgumentError);

// Prompts
expect(() => (empty as GetPromptRequest).name, throwsArgumentError);

// Completions
expect(() => (empty as CompleteRequest).ref, throwsArgumentError);
expect(() => (empty as CompleteRequest).argument, throwsArgumentError);

// Logging
expect(() => (empty as SetLevelRequest).level, throwsArgumentError);

// Sampling
expect(
() => (empty as CreateMessageRequest).messages,
throwsArgumentError,
);
expect(
() => (empty as CreateMessageRequest).maxTokens,
throwsArgumentError,
);
});
});
}