Skip to content

Commit a8022c8

Browse files
orpuente-MSidavis
authored andcommitted
functions and operations with const eval (#2270)
Compile functions and operations const evaluating any references to symbol in an external scope.
1 parent 483d384 commit a8022c8

File tree

12 files changed

+405
-87
lines changed

12 files changed

+405
-87
lines changed

compiler/qsc_qasm3/src/ast_builder.rs

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ use num_bigint::BigInt;
77

88
use qsc_ast::ast::{
99
self, Attr, Block, CallableBody, CallableDecl, CallableKind, Expr, ExprKind, FieldAssign,
10-
Ident, ImportOrExportDecl, ImportOrExportItem, Item, ItemKind, Lit, Mutability, NodeId, Pat,
11-
PatKind, Path, PathKind, QubitInit, QubitInitKind, QubitSource, Stmt, StmtKind, TopLevelNode,
12-
Ty, TyKind,
10+
FunctorExpr, FunctorExprKind, Ident, ImportOrExportDecl, ImportOrExportItem, Item, ItemKind,
11+
Lit, Mutability, NodeId, Pat, PatKind, Path, PathKind, QubitInit, QubitInitKind, QubitSource,
12+
Stmt, StmtKind, TopLevelNode, Ty, TyKind,
1313
};
1414
use qsc_data_structures::span::Span;
1515

@@ -1691,6 +1691,108 @@ pub(crate) fn build_lambda<S: AsRef<str>>(
16911691
}
16921692
}
16931693

1694+
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
1695+
pub(crate) fn build_function_or_operation(
1696+
name: String,
1697+
cargs: Vec<(String, Ty, Pat)>,
1698+
qargs: Vec<(String, Ty, Pat)>,
1699+
body: Option<Block>,
1700+
name_span: Span,
1701+
body_span: Span,
1702+
gate_span: Span,
1703+
return_type: Option<Ty>,
1704+
kind: CallableKind,
1705+
functors: Option<FunctorExpr>,
1706+
) -> Stmt {
1707+
let args = cargs
1708+
.into_iter()
1709+
.chain(qargs)
1710+
.map(|(_, _, pat)| Box::new(pat))
1711+
.collect::<Vec<_>>();
1712+
1713+
let lo = args
1714+
.iter()
1715+
.min_by_key(|x| x.span.lo)
1716+
.map(|x| x.span.lo)
1717+
.unwrap_or_default();
1718+
1719+
let hi = args
1720+
.iter()
1721+
.max_by_key(|x| x.span.hi)
1722+
.map(|x| x.span.hi)
1723+
.unwrap_or_default();
1724+
1725+
let input_pat_kind = if args.len() == 1 {
1726+
PatKind::Paren(args[0].clone())
1727+
} else {
1728+
PatKind::Tuple(args.into_boxed_slice())
1729+
};
1730+
1731+
let input_pat = Pat {
1732+
kind: Box::new(input_pat_kind),
1733+
span: Span { lo, hi },
1734+
..Default::default()
1735+
};
1736+
1737+
let return_type = if let Some(ty) = return_type {
1738+
ty
1739+
} else {
1740+
build_path_ident_ty("Unit")
1741+
};
1742+
1743+
let body = CallableBody::Block(Box::new(body.unwrap_or_else(|| Block {
1744+
id: NodeId::default(),
1745+
span: body_span,
1746+
stmts: Box::new([]),
1747+
})));
1748+
1749+
let decl = CallableDecl {
1750+
id: NodeId::default(),
1751+
span: name_span,
1752+
kind,
1753+
name: Box::new(Ident {
1754+
name: name.into(),
1755+
..Default::default()
1756+
}),
1757+
generics: Box::new([]),
1758+
input: Box::new(input_pat),
1759+
output: Box::new(return_type),
1760+
functors: functors.map(Box::new),
1761+
body: Box::new(body),
1762+
};
1763+
let item = Item {
1764+
span: gate_span,
1765+
kind: Box::new(ast::ItemKind::Callable(Box::new(decl))),
1766+
..Default::default()
1767+
};
1768+
1769+
Stmt {
1770+
kind: Box::new(StmtKind::Item(Box::new(item))),
1771+
span: gate_span,
1772+
..Default::default()
1773+
}
1774+
}
1775+
1776+
pub(crate) fn build_adj_plus_ctl_functor() -> FunctorExpr {
1777+
let adj = Box::new(FunctorExpr {
1778+
kind: Box::new(FunctorExprKind::Lit(ast::Functor::Adj)),
1779+
id: Default::default(),
1780+
span: Default::default(),
1781+
});
1782+
1783+
let ctl = Box::new(FunctorExpr {
1784+
kind: Box::new(FunctorExprKind::Lit(ast::Functor::Ctl)),
1785+
id: Default::default(),
1786+
span: Default::default(),
1787+
});
1788+
1789+
FunctorExpr {
1790+
kind: Box::new(FunctorExprKind::BinOp(ast::SetOp::Union, adj, ctl)),
1791+
id: Default::default(),
1792+
span: Default::default(),
1793+
}
1794+
}
1795+
16941796
fn build_idents(idents: &[&str]) -> Option<Box<[Ident]>> {
16951797
let idents = idents
16961798
.iter()

compiler/qsc_qasm3/src/compiler.rs

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,22 @@ use qsc_frontend::{compile::SourceMap, error::WithSource};
1010

1111
use crate::{
1212
ast_builder::{
13-
build_arg_pat, build_array_reverse_expr, build_assignment_statement, build_barrier_call,
14-
build_binary_expr, build_call_no_params, build_call_with_param, build_call_with_params,
15-
build_cast_call_by_name, build_classical_decl, build_complex_from_expr,
16-
build_convert_call_expr, build_expr_array_expr, build_for_stmt, build_gate_call_param_expr,
17-
build_gate_call_with_params_and_callee, build_global_call_with_two_params,
18-
build_if_expr_then_block, build_if_expr_then_block_else_block,
19-
build_if_expr_then_block_else_expr, build_if_expr_then_expr_else_expr,
20-
build_implicit_return_stmt, build_indexed_assignment_statement, build_lambda,
21-
build_lit_angle_expr, build_lit_bigint_expr, build_lit_bool_expr, build_lit_complex_expr,
22-
build_lit_double_expr, build_lit_int_expr, build_lit_result_array_expr_from_bitstring,
23-
build_lit_result_expr, build_managed_qubit_alloc, build_math_call_from_exprs,
24-
build_math_call_no_params, build_measure_call, build_operation_with_stmts,
25-
build_path_ident_expr, build_qasm_import_decl, build_qasm_import_items, build_range_expr,
26-
build_reset_call, build_return_expr, build_return_unit, build_stmt_semi_from_expr,
13+
build_adj_plus_ctl_functor, build_arg_pat, build_array_reverse_expr,
14+
build_assignment_statement, build_barrier_call, build_binary_expr, build_call_no_params,
15+
build_call_with_param, build_call_with_params, build_cast_call_by_name,
16+
build_classical_decl, build_complex_from_expr, build_convert_call_expr,
17+
build_expr_array_expr, build_for_stmt, build_function_or_operation,
18+
build_gate_call_param_expr, build_gate_call_with_params_and_callee,
19+
build_global_call_with_two_params, build_if_expr_then_block,
20+
build_if_expr_then_block_else_block, build_if_expr_then_block_else_expr,
21+
build_if_expr_then_expr_else_expr, build_implicit_return_stmt,
22+
build_indexed_assignment_statement, build_lit_angle_expr, build_lit_bigint_expr,
23+
build_lit_bool_expr, build_lit_complex_expr, build_lit_double_expr, build_lit_int_expr,
24+
build_lit_result_array_expr_from_bitstring, build_lit_result_expr,
25+
build_managed_qubit_alloc, build_math_call_from_exprs, build_math_call_no_params,
26+
build_measure_call, build_operation_with_stmts, build_path_ident_expr,
27+
build_qasm_import_decl, build_qasm_import_items, build_range_expr, build_reset_call,
28+
build_return_expr, build_return_unit, build_stmt_semi_from_expr,
2729
build_stmt_semi_from_expr_with_span, build_top_level_ns_with_items, build_tuple_expr,
2830
build_unary_op_expr, build_unmanaged_qubit_alloc, build_unmanaged_qubit_alloc_array,
2931
build_while_stmt, build_wrapped_block_expr, managed_qubit_alloc_array,
@@ -627,7 +629,7 @@ impl QasmCompiler {
627629

628630
// We use the same primitives used for declaring gates, because def declarations
629631
// in QASM3 can take qubits as arguments and call quantum gates.
630-
Some(build_lambda(
632+
Some(build_function_or_operation(
631633
name,
632634
cargs,
633635
vec![],
@@ -637,6 +639,7 @@ impl QasmCompiler {
637639
stmt.span,
638640
return_type,
639641
kind,
642+
None,
640643
))
641644
}
642645

@@ -899,7 +902,7 @@ impl QasmCompiler {
899902

900903
let body = Some(self.compile_block(&stmt.body));
901904

902-
Some(build_lambda(
905+
Some(build_function_or_operation(
903906
name,
904907
cargs,
905908
qargs,
@@ -909,6 +912,7 @@ impl QasmCompiler {
909912
stmt.span,
910913
None,
911914
qsast::CallableKind::Operation,
915+
Some(build_adj_plus_ctl_functor()),
912916
))
913917
}
914918

compiler/qsc_qasm3/src/semantic/lowerer.rs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ impl Lowerer {
257257
gate_symbol("cx", 0, 2),
258258
gate_symbol("cy", 0, 2),
259259
gate_symbol("cz", 0, 2),
260-
gate_symbol("cp", 0, 2),
260+
gate_symbol("cp", 1, 2),
261261
gate_symbol("swap", 0, 2),
262262
gate_symbol("ccx", 0, 3),
263263
gate_symbol("cu", 4, 2),
@@ -520,7 +520,45 @@ impl Lowerer {
520520

521521
let (symbol_id, symbol) = self.try_get_existing_or_insert_err_symbol(&name, ident.span);
522522

523-
let kind = semantic::ExprKind::Ident(symbol_id);
523+
// Design Note: The end goal of this const evaluation is to be able to compile qasm
524+
// annotations as Q# attributes like `@SimulatableIntrinsic()`.
525+
//
526+
// QASM3 subroutines and gates can be recursive and capture const symbols
527+
// outside their scope. In Q#, only lambdas can capture symbols, but only
528+
// proper functions and operations can be recursive or have attributes on
529+
// them. To get both, annotations & recursive gates/functions and the
530+
// ability to capture const symbols outside the gate/function scope, we
531+
// decided to compile the gates/functions as proper Q# operations/functions
532+
// and evaluate at lowering-time all references to const symbols outside
533+
// the current gate/function scope.
534+
535+
// This is true if we are inside any gate or function scope.
536+
let is_symbol_inside_gate_or_function_scope =
537+
self.symbols.is_scope_rooted_in_gate_or_subroutine();
538+
539+
// This is true if the symbol is outside the most inner gate or function scope.
540+
let is_symbol_outside_most_inner_gate_or_function_scope = self
541+
.symbols
542+
.is_symbol_outside_most_inner_gate_or_function_scope(symbol_id);
543+
544+
let is_const_evaluation_necessary = symbol.is_const()
545+
&& is_symbol_inside_gate_or_function_scope
546+
&& is_symbol_outside_most_inner_gate_or_function_scope;
547+
548+
let kind = if is_const_evaluation_necessary {
549+
if let Some(val) = symbol.get_const_expr().const_eval(&self.symbols) {
550+
semantic::ExprKind::Lit(val)
551+
} else {
552+
self.push_semantic_error(SemanticErrorKind::ExprMustBeConst(
553+
ident.name.to_string(),
554+
ident.span,
555+
));
556+
semantic::ExprKind::Err
557+
}
558+
} else {
559+
semantic::ExprKind::Ident(symbol_id)
560+
};
561+
524562
semantic::Expr {
525563
span: ident.span,
526564
kind: Box::new(kind),

compiler/qsc_qasm3/src/semantic/symbols.rs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
// Licensed under the MIT License.
33

44
use core::f64;
5-
use std::rc::Rc;
6-
75
use qsc_data_structures::{index_map::IndexMap, span::Span};
86
use rustc_hash::FxHashMap;
7+
use std::rc::Rc;
98

109
use super::{
1110
ast::{Expr, ExprKind, LiteralKind},
@@ -137,6 +136,12 @@ impl Symbol {
137136
}
138137
}
139138

139+
/// Returns true if they symbol's value is a const expr.
140+
#[must_use]
141+
pub fn is_const(&self) -> bool {
142+
self.const_expr.is_some()
143+
}
144+
140145
/// Returns the value of the symbol.
141146
#[must_use]
142147
pub fn get_const_expr(&self) -> Rc<Expr> {
@@ -468,6 +473,22 @@ impl SymbolTable {
468473
None
469474
}
470475

476+
#[must_use]
477+
pub fn is_symbol_outside_most_inner_gate_or_function_scope(&self, symbol_id: SymbolId) -> bool {
478+
for scope in self.scopes.iter().rev() {
479+
if scope.id_to_symbol.contains_key(&symbol_id) {
480+
return false;
481+
}
482+
if matches!(
483+
scope.kind,
484+
ScopeKind::Gate | ScopeKind::Function | ScopeKind::Global
485+
) {
486+
return true;
487+
}
488+
}
489+
unreachable!("when the loop ends we will have visited at least the Global scope");
490+
}
491+
471492
#[must_use]
472493
pub fn is_current_scope_global(&self) -> bool {
473494
matches!(self.scopes.last(), Some(scope) if scope.kind == ScopeKind::Global)
@@ -481,6 +502,14 @@ impl SymbolTable {
481502
.any(|scope| scope.kind == ScopeKind::Function)
482503
}
483504

505+
#[must_use]
506+
pub fn is_scope_rooted_in_gate_or_subroutine(&self) -> bool {
507+
self.scopes
508+
.iter()
509+
.rev()
510+
.any(|scope| matches!(scope.kind, ScopeKind::Gate | ScopeKind::Function))
511+
}
512+
484513
#[must_use]
485514
pub fn is_scope_rooted_in_global(&self) -> bool {
486515
for scope in self.scopes.iter().rev() {

compiler/qsc_qasm3/src/tests/declaration/def.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ fn no_parameters_no_return() -> miette::Result<(), Vec<Report>> {
1111

1212
let qsharp = compile_qasm_stmt_to_qsharp(source)?;
1313
expect![[r#"
14-
let empty : () -> Unit = () -> {};
14+
function empty() : Unit {}
1515
"#]]
1616
.assert_eq(&qsharp);
1717
Ok(())
@@ -27,9 +27,9 @@ fn single_parameter() -> miette::Result<(), Vec<Report>> {
2727

2828
let qsharp = compile_qasm_stmt_to_qsharp(source)?;
2929
expect![[r#"
30-
let square : (Int) -> Int = (x) -> {
30+
function square(x : Int) : Int {
3131
return x * x;
32-
};
32+
}
3333
"#]]
3434
.assert_eq(&qsharp);
3535
Ok(())
@@ -45,9 +45,9 @@ fn qubit_parameter() -> miette::Result<(), Vec<Report>> {
4545

4646
let qsharp = compile_qasm_stmt_to_qsharp(source)?;
4747
expect![[r#"
48-
let square : (Qubit) => Int = (q) => {
48+
operation square(q : Qubit) : Int {
4949
return 1;
50-
};
50+
}
5151
"#]]
5252
.assert_eq(&qsharp);
5353
Ok(())
@@ -63,9 +63,9 @@ fn qubit_array_parameter() -> miette::Result<(), Vec<Report>> {
6363

6464
let qsharp = compile_qasm_stmt_to_qsharp(source)?;
6565
expect![[r#"
66-
let square : (Qubit[]) => Int = (qs) => {
66+
operation square(qs : Qubit[]) : Int {
6767
return 1;
68-
};
68+
}
6969
"#]]
7070
.assert_eq(&qsharp);
7171
Ok(())

0 commit comments

Comments
 (0)