Skip to content

Commit 70a7ef3

Browse files
lambdabaacommit-bot@chromium.org
authored andcommitted
Implement pointer mixture network
Here we can see the mixture network is assigning probability mass to a new, project-specific name by reference https://i.imgur.com/6Zbs2qf.png. I also took this opportunity to decrease model size targeting 100M, in line with our original size goals. My initial strategy was to implement a separate pointer network in https://dart-review.googlesource.com/c/sdk/+/117005 but having a single network that can assign probability mass across local references and vocabulary lexemes is better since 1) only one network and model file 2) no need to coalesce predictions from multiple models Change-Id: I23cfc2ece61ce30bb69785149a5a6cf1604af18d Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/121461 Commit-Queue: Ari Aye <ariaye@google.com> Reviewed-by: Brian Wilkerson <brianwilkerson@google.com>
1 parent 263bfd9 commit 70a7ef3

File tree

6 files changed

+49
-19
lines changed

6 files changed

+49
-19
lines changed

DEPS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ deps = {
427427
"packages": [
428428
{
429429
"package": "dart/language_model",
430-
"version": "EFtZ0Z5T822s4EUOOaWeiXUppRGKp5d9Z6jomJIeQYcC",
430+
"version": "9fJQZ0TrnAGQKrEtuL3-AXbUfPzYxqpN_OBHr9P4hE4C",
431431
}
432432
],
433433
"dep_type": "cipd",

pkg/analysis_server/lib/src/services/completion/dart/completion_ranking.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import 'package:analysis_server/src/services/completion/dart/language_model.dart
1212
import 'package:analyzer/dart/analysis/features.dart';
1313

1414
/// Number of lookback tokens.
15-
const int _LOOKBACK = 100;
15+
const int _LOOKBACK = 50;
1616

1717
/// Minimum probability to prioritize model-only suggestion.
1818
const double _MODEL_RELEVANCE_CUTOFF = 0.5;

pkg/analysis_server/lib/src/services/completion/dart/completion_ranking_internal.dart

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,12 @@ List<String> constructQuery(DartCompletionRequest request, int n) {
174174
size < n && token != null && !token.isEof;
175175
token = token.previous) {
176176
if (!token.isSynthetic && token is! ErrorToken) {
177+
// Omit the optional new keyword as we remove it at training time to
178+
// prevent model from suggesting it.
179+
if (token.lexeme == 'new') {
180+
continue;
181+
}
182+
177183
result.add(token.lexeme);
178184
size += 1;
179185
}

pkg/analysis_server/lib/src/services/completion/dart/language_model.dart

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import 'package:tflite_native/tflite.dart' as tfl;
1313
/// Interface to TensorFlow-based Dart language model for next-token prediction.
1414
class LanguageModel {
1515
static const _defaultCompletions = 100;
16+
static final _numeric = RegExp(r'^\d+(.\d+)?$');
1617

1718
final tfl.Interpreter _interpreter;
1819
final Map<String, int> _word2idx;
@@ -63,48 +64,71 @@ class LanguageModel {
6364
/// Predicts the next token to follow a list of precedent tokens
6465
///
6566
/// Returns a list of tokens, sorted by most probable first.
66-
List<String> predict(Iterable<String> tokens) =>
67+
List<String> predict(List<String> tokens) =>
6768
predictWithScores(tokens).keys.toList();
6869

6970
/// Predicts the next token with confidence scores.
7071
///
7172
/// Returns an ordered map of tokens to scores, sorted by most probable first.
72-
Map<String, double> predictWithScores(Iterable<String> tokens) {
73+
Map<String, double> predictWithScores(List<String> tokens) {
7374
final tensorIn = _interpreter.getInputTensors().single;
7475
tensorIn.data = _transformInput(tokens);
7576
_interpreter.invoke();
7677
final tensorOut = _interpreter.getOutputTensors().single;
77-
return _transformOutput(tensorOut.data);
78+
return _transformOutput(tensorOut.data, tokens);
7879
}
7980

8081
/// Transforms tokens to data bytes that can be used as interpreter input.
81-
List<int> _transformInput(Iterable<String> tokens) {
82+
List<int> _transformInput(List<String> tokens) {
8283
// Replace out of vocabulary tokens.
83-
final sanitizedTokens = tokens
84-
.map((token) => _word2idx.containsKey(token) ? token : '<unknown>');
85-
84+
final sanitizedTokens = tokens.map((token) {
85+
if (_word2idx.containsKey(token)) {
86+
return token;
87+
}
88+
if (_numeric.hasMatch(token)) {
89+
return '<num>';
90+
}
91+
if (_isString(token)) {
92+
return '<str>';
93+
}
94+
return '<unk>';
95+
});
8696
// Get indexes (as floats).
8797
final indexes = Float32List(lookback)
8898
..setAll(0, sanitizedTokens.map((token) => _word2idx[token].toDouble()));
89-
9099
// Get bytes
91100
return Uint8List.view(indexes.buffer);
92101
}
93102

94103
/// Transforms interpreter output data to map of tokens to scores.
95-
Map<String, double> _transformOutput(List<int> databytes) {
104+
Map<String, double> _transformOutput(
105+
List<int> databytes, List<String> tokens) {
96106
// Get bytes.
97107
final bytes = Uint8List.fromList(databytes);
98108

99109
// Get scores (as floats)
100110
final probabilities = Float32List.view(bytes.buffer);
101111

102-
// Get indexes with scores, sorted by scores (descending)
103-
final entries = probabilities.asMap().entries.toList()
112+
final scores = Map<String, double>();
113+
probabilities.asMap().forEach((k, v) {
114+
// x in 0, 1, ..., |V| - 1 correspond to specific members of the vocabulary.
115+
// x in |V|, |V| + 1, ..., |V| + 49 are pointers to reference positions along the
116+
// network input.
117+
if (k >= _idx2word.length + tokens.length) {
118+
return;
119+
}
120+
final lexeme =
121+
k < _idx2word.length ? _idx2word[k] : tokens[k - _idx2word.length];
122+
final sanitized = lexeme.replaceAll('"', '\'');
123+
scores[sanitized] = (scores[sanitized] ?? 0.0) + v;
124+
});
125+
126+
final entries = scores.entries.toList()
104127
..sort((a, b) => b.value.compareTo(a.value));
128+
return Map.fromEntries(entries.sublist(0, completions));
129+
}
105130

106-
// Get tokens with scores, limiting the length.
107-
return Map.fromEntries(entries.sublist(0, completions))
108-
.map((k, v) => MapEntry(_idx2word[k].replaceAll('"', '\''), v));
131+
bool _isString(String token) {
132+
return token.indexOf('"') != -1 || token.indexOf("'") != -1;
109133
}
110134
}

pkg/analysis_server/test/services/completion/dart/completion_ranking_test.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ void main() {
2020
final tokens =
2121
tokenize('if (list == null) { return; } for (final i = 0; i < list.');
2222
final response = await ranking.makeRequest('predict', tokens);
23-
expect(response['data']['length'], greaterThan(0.95));
23+
expect(response['data']['length'], greaterThan(0.85));
2424
});
2525
}
2626

pkg/analysis_server/test/services/completion/dart/language_model_test.dart

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import 'package:test/test.dart';
1111

1212
final directory = path.join(File.fromUri(Platform.script).parent.path, '..',
1313
'..', '..', '..', 'language_model', 'lexeme');
14-
const expectedLookback = 100;
14+
const expectedLookback = 50;
1515

1616
void main() {
1717
if (sizeOf<IntPtr>() == 4) {
@@ -47,7 +47,7 @@ void main() {
4747
final suggestions = model.predictWithScores(tokens);
4848
final best = suggestions.entries.first;
4949
expect(best.key, 'length');
50-
expect(best.value, greaterThan(0.8));
50+
expect(best.value, greaterThan(0.85));
5151
expect(suggestions, hasLength(model.completions));
5252
});
5353

0 commit comments

Comments
 (0)