-
Notifications
You must be signed in to change notification settings - Fork 249
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support multi-line code completions
- Loading branch information
1 parent
6fef9c2
commit f84a659
Showing
23 changed files
with
650 additions
and
347 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 82 additions & 10 deletions
92
codegpt-treesitter/src/main/java/ee/carlrobert/codegpt/treesitter/CodeCompletionParser.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,116 @@ | ||
package ee.carlrobert.codegpt.treesitter; | ||
|
||
import java.nio.charset.StandardCharsets; | ||
import org.treesitter.TSInputEdit; | ||
import org.treesitter.TSLanguage; | ||
import org.treesitter.TSParser; | ||
import org.treesitter.TSPoint; | ||
import org.treesitter.TSTree; | ||
|
||
public class CodeCompletionParser { | ||
|
||
protected final TSLanguage language; | ||
private final TSParser parser; | ||
|
||
public CodeCompletionParser(TSLanguage language) { | ||
this.language = language; | ||
parser = new TSParser(); | ||
parser.setLanguage(language); | ||
} | ||
|
||
public String parse(String prefix, String suffix, String output) { | ||
var result = new StringBuilder(output); | ||
String input = prefix + result + suffix; | ||
TSTree currentTree = parser.parseString(null, input); | ||
|
||
while (!result.isEmpty()) { | ||
if (containsError(prefix + result + suffix)) { | ||
if (containsError(currentTree)) { | ||
int deletionIndex = prefix.length() + result.length() - 1; | ||
Position pos = getPosition(input, deletionIndex); | ||
|
||
int startByte = pos.byteOffset; | ||
int oldEndByte = startByte + getByteLength(result.substring(result.length() - 1)); | ||
|
||
TSPoint startPoint = pos.point; | ||
TSPoint oldEndPoint = computeOldEndPoint(startPoint, result.charAt(result.length() - 1)); | ||
|
||
currentTree.edit( | ||
new TSInputEdit(startByte, oldEndByte, startByte, startPoint, oldEndPoint, startPoint)); | ||
|
||
result.deleteCharAt(result.length() - 1); | ||
|
||
if (result.length() > 1 && result.charAt(result.length() - 1) == '{') { | ||
long bracketCount = result.chars().filter(ch -> ch == '{').count(); | ||
if (bracketCount == 1) { | ||
var newTree = parser.parseString(currentTree, prefix + result + "}" + suffix); | ||
var treeString = newTree.getRootNode().toString(); | ||
if (!treeString.contains("ERROR")) { | ||
return result + "}"; | ||
} | ||
} | ||
} | ||
|
||
input = prefix + result + suffix; | ||
|
||
currentTree = parser.parseString(currentTree, input); | ||
} else { | ||
return result.toString(); | ||
} | ||
} | ||
|
||
if (output.contains("\n")) { | ||
return parse(prefix, suffix, output.substring(0, output.indexOf("\n"))); | ||
var finalResult = output.substring(0, output.indexOf("\n")); | ||
if (finalResult.charAt(finalResult.length() - 1) == '{') { | ||
return finalResult + "}"; | ||
} | ||
return finalResult; | ||
} | ||
|
||
return output; | ||
} | ||
|
||
private boolean containsError(String input) { | ||
var treeString = getTree(input).getRootNode().toString(); | ||
private boolean containsError(TSTree tree) { | ||
var treeString = tree.getRootNode().toString(); | ||
return treeString.contains("ERROR") | ||
|| treeString.contains("MISSING \"}\"") | ||
|| treeString.contains("MISSING \")\""); | ||
} | ||
|
||
private TSTree getTree(String input) { | ||
var parser = new TSParser(); | ||
parser.setLanguage(language); | ||
return parser.parseString(null, input); | ||
private Position getPosition(String input, int index) { | ||
int row = 0; | ||
int col = 0; | ||
int byteOffset = 0; | ||
for (int i = 0; i < index; i++) { | ||
char c = input.charAt(i); | ||
int charByteLength = getByteLength(String.valueOf(c)); | ||
byteOffset += charByteLength; | ||
|
||
if (c == '\n') { | ||
row++; | ||
col = 0; | ||
} else { | ||
col++; | ||
} | ||
} | ||
return new Position(new TSPoint(row, col), byteOffset); | ||
} | ||
|
||
private int getByteLength(String str) { | ||
return str.getBytes(StandardCharsets.UTF_8).length; | ||
} | ||
|
||
private TSPoint computeOldEndPoint(TSPoint startPoint, char deletedChar) { | ||
int row = startPoint.getRow(); | ||
int col = startPoint.getColumn(); | ||
|
||
if (deletedChar == '\n') { | ||
row++; | ||
col = 0; | ||
} else { | ||
col++; | ||
} | ||
return new TSPoint(row, col); | ||
} | ||
|
||
private record Position(TSPoint point, int byteOffset) { | ||
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.