Skip to content

Commit a3912bc

Browse files
committed
feat: more constant checks
1 parent d7f841b commit a3912bc

File tree

3 files changed

+28
-7
lines changed

3 files changed

+28
-7
lines changed

src/Ops.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,12 @@ for (T, mlir_func) in (
206206

207207
splatattr = MLIR.API.$mlir_func(tt, number)
208208
cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location)
209+
210+
parent_func_op = MLIR.IR.get_parent_of_type_function_op(cst_op)
211+
if parent_func_op == C_NULL
212+
error("Constant must be created inside a Function Op.")
213+
end
214+
209215
cst = MLIR.IR.result(cst_op)
210216
ta = TracedRArray{$T,length(shape)}((), cst, shape)
211217
return ta
@@ -226,6 +232,12 @@ end
226232
tt = MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
227233
splatattr = MLIR.API.mlirDenseElementsAttrSplatGet(tt, _fill_element_attr(element))
228234
cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location)
235+
236+
parent_func_op = MLIR.IR.get_parent_of_type_function_op(cst_op)
237+
if parent_func_op == C_NULL
238+
error("Constant must be created inside a Function Op.")
239+
end
240+
229241
cst = MLIR.IR.result(cst_op)
230242
ta = TracedRArray{T,length(shape)}((), cst, shape)
231243
return ta

src/mlir/IR/Operation.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ Gets the operation that owns this operation, returning null if the operation is
6868
parent_op(operation::Operation) =
6969
Operation(API.mlirOperationGetParentOperation(operation), false)
7070

71+
"""
72+
parent_region(op)
73+
Gets the region that owns this operation.
74+
"""
75+
parent_region(operation::Operation) = parent_region(block(operation))
76+
7177
"""
7278
rmfromparent!(op)
7379
@@ -333,13 +339,14 @@ end
333339

334340
function create_operation_common_with_checks(args...; operands=nothing, kwargs...)
335341
op = create_operation_common(args...; operands, kwargs...)
336-
# if !isnothing(operands)
337-
# parent_function_op = get_parent_of_type_function_op(op)
338-
# if parent_function_op != C_NULL
339-
# function_op_region = parent_region(parent_function_op)
340-
# # TODO: add the checks
341-
# end
342-
# end
342+
if !isnothing(operands)
343+
parent_function_op = get_parent_of_type_function_op(op)
344+
if parent_function_op != C_NULL
345+
function_op_region = parent_region(parent_function_op)
346+
operand_region = parent_region.(operands)
347+
# TODO: add the checks
348+
end
349+
end
343350
return op
344351
end
345352

src/mlir/IR/Value.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,5 @@ function Base.show(io::IO, value::Value)
121121
API.mlirValuePrint(value, c_print_callback, ref)
122122
end
123123
end
124+
125+
parent_region(value::Value) = parent_region(owner(value))

0 commit comments

Comments
 (0)