Skip to content

Commit d00a285

Browse files
theotherphilmatklad
authored andcommitted
Initial implementation of Ok-wrapping
1 parent fdece91 commit d00a285

File tree

4 files changed

+136
-3
lines changed

4 files changed

+136
-3
lines changed

crates/ra_hir/src/diagnostics.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,34 @@ impl AstDiagnostic for MissingFields {
143143
ast::RecordFieldList::cast(node).unwrap()
144144
}
145145
}
146+
147+
#[derive(Debug)]
148+
pub struct MissingOkInTailExpr {
149+
pub file: HirFileId,
150+
pub expr: AstPtr<ast::Expr>,
151+
}
152+
153+
impl Diagnostic for MissingOkInTailExpr {
154+
fn message(&self) -> String {
155+
"wrap return expression in Ok".to_string()
156+
}
157+
fn file(&self) -> HirFileId {
158+
self.file
159+
}
160+
fn syntax_node_ptr(&self) -> SyntaxNodePtr {
161+
self.expr.into()
162+
}
163+
fn as_any(&self) -> &(dyn Any + Send + 'static) {
164+
self
165+
}
166+
}
167+
168+
impl AstDiagnostic for MissingOkInTailExpr {
169+
type AST = ast::Expr;
170+
171+
fn ast(&self, db: &impl HirDatabase) -> Self::AST {
172+
let root = db.parse_or_expand(self.file()).unwrap();
173+
let node = self.syntax_node_ptr().to_node(&root);
174+
ast::Expr::cast(node).unwrap()
175+
}
176+
}

crates/ra_hir/src/expr/validation.rs

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@ use ra_syntax::ast::{AstNode, RecordLit};
66
use super::{Expr, ExprId, RecordLitField};
77
use crate::{
88
adt::AdtDef,
9-
diagnostics::{DiagnosticSink, MissingFields},
9+
diagnostics::{DiagnosticSink, MissingFields, MissingOkInTailExpr},
1010
expr::AstPtr,
11-
ty::InferenceResult,
11+
ty::{InferenceResult, Ty, TypeCtor},
1212
Function, HasSource, HirDatabase, Name, Path,
1313
};
14+
use ra_syntax::ast;
1415

1516
pub(crate) struct ExprValidator<'a, 'b: 'a> {
1617
func: Function,
@@ -29,11 +30,23 @@ impl<'a, 'b> ExprValidator<'a, 'b> {
2930

3031
pub(crate) fn validate_body(&mut self, db: &impl HirDatabase) {
3132
let body = self.func.body(db);
33+
34+
// The final expr in the function body is the whole body,
35+
// so the expression being returned is the penultimate expr.
36+
let mut penultimate_expr = None;
37+
let mut final_expr = None;
38+
3239
for e in body.exprs() {
40+
penultimate_expr = final_expr;
41+
final_expr = Some(e);
42+
3343
if let (id, Expr::RecordLit { path, fields, spread }) = e {
3444
self.validate_record_literal(id, path, fields, *spread, db);
3545
}
3646
}
47+
if let Some(e) = penultimate_expr {
48+
self.validate_results_in_tail_expr(e.0, db);
49+
}
3750
}
3851

3952
fn validate_record_literal(
@@ -87,4 +100,43 @@ impl<'a, 'b> ExprValidator<'a, 'b> {
87100
})
88101
}
89102
}
103+
104+
fn validate_results_in_tail_expr(&mut self, id: ExprId, db: &impl HirDatabase) {
105+
let expr_ty = &self.infer[id];
106+
let func_ty = self.func.ty(db);
107+
let func_sig = func_ty.callable_sig(db).unwrap();
108+
let ret = func_sig.ret();
109+
let ret = match ret {
110+
Ty::Apply(t) => t,
111+
_ => return,
112+
};
113+
let ret_enum = match ret.ctor {
114+
TypeCtor::Adt(AdtDef::Enum(e)) => e,
115+
_ => return,
116+
};
117+
let enum_name = ret_enum.name(db);
118+
if enum_name.is_none() || enum_name.unwrap().to_string() != "Result" {
119+
return;
120+
}
121+
let params = &ret.parameters;
122+
if params.len() == 2 && &params[0] == expr_ty {
123+
let source_map = self.func.body_source_map(db);
124+
let file_id = self.func.source(db).file_id;
125+
let parse = db.parse(file_id.original_file(db));
126+
let source_file = parse.tree();
127+
let expr_syntax = source_map.expr_syntax(id);
128+
if expr_syntax.is_none() {
129+
return;
130+
}
131+
let expr_syntax = expr_syntax.unwrap();
132+
let node = expr_syntax.to_node(source_file.syntax());
133+
let ast = ast::Expr::cast(node);
134+
if ast.is_none() {
135+
return;
136+
}
137+
let ast = ast.unwrap();
138+
139+
self.sink.push(MissingOkInTailExpr { file: file_id, expr: AstPtr::new(&ast) });
140+
}
141+
}
90142
}

crates/ra_hir/src/ty.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ impl Ty {
516516
}
517517
}
518518

519-
fn callable_sig(&self, db: &impl HirDatabase) -> Option<FnSig> {
519+
pub fn callable_sig(&self, db: &impl HirDatabase) -> Option<FnSig> {
520520
match self {
521521
Ty::Apply(a_ty) => match a_ty.ctor {
522522
TypeCtor::FnPtr { .. } => Some(FnSig::from_fn_ptr_substs(&a_ty.parameters)),

crates/ra_ide_api/src/diagnostics.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,19 @@ pub(crate) fn diagnostics(db: &RootDatabase, file_id: FileId) -> Vec<Diagnostic>
7575
severity: Severity::Error,
7676
fix: Some(fix),
7777
})
78+
})
79+
.on::<hir::diagnostics::MissingOkInTailExpr, _>(|d| {
80+
let node = d.ast(db);
81+
let mut builder = TextEditBuilder::default();
82+
let replacement = format!("Ok({})", node.syntax().text());
83+
builder.replace(node.syntax().text_range(), replacement);
84+
let fix = SourceChange::source_file_edit_from("wrap with ok", file_id, builder.finish());
85+
res.borrow_mut().push(Diagnostic {
86+
range: d.highlight_range(),
87+
message: d.message(),
88+
severity: Severity::Error,
89+
fix: Some(fix),
90+
})
7891
});
7992
if let Some(m) = source_binder::module_from_file_id(db, file_id) {
8093
m.diagnostics(db, &mut sink);
@@ -218,6 +231,43 @@ mod tests {
218231
assert_eq!(diagnostics.len(), 0);
219232
}
220233

234+
#[test]
235+
fn test_wrap_return_type() {
236+
let before = r#"
237+
enum Result<T, E> { Ok(T), Err(E) }
238+
struct String { }
239+
240+
fn div(x: i32, y: i32) -> Result<i32, String> {
241+
if y == 0 {
242+
return Err("div by zero".into());
243+
}
244+
x / y
245+
}
246+
"#;
247+
let after = r#"
248+
enum Result<T, E> { Ok(T), Err(E) }
249+
struct String { }
250+
251+
fn div(x: i32, y: i32) -> Result<i32, String> {
252+
if y == 0 {
253+
return Err("div by zero".into());
254+
}
255+
Ok(x / y)
256+
}
257+
"#;
258+
check_apply_diagnostic_fix(before, after);
259+
}
260+
261+
#[test]
262+
fn test_wrap_return_type_not_applicable() {
263+
let content = r#"
264+
fn foo() -> Result<String, i32> {
265+
0
266+
}
267+
"#;
268+
check_no_diagnostic(content);
269+
}
270+
221271
#[test]
222272
fn test_fill_struct_fields_empty() {
223273
let before = r"

0 commit comments

Comments
 (0)