Skip to content

Commit

Permalink
fix some lint issues
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Oct 29, 2024
1 parent 41afee6 commit 0799e71
Show file tree
Hide file tree
Showing 31 changed files with 4 additions and 1 deletion.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
5 changes: 4 additions & 1 deletion onnx/defs/nn/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <algorithm>
#include <cmath>
#include <limits>

#include "onnx/common/assertions.h"
#include "onnx/defs/function.h"
Expand Down Expand Up @@ -3167,9 +3168,10 @@ ONNX_OPERATOR_SET_SCHEMA(
// An error is thrown if both attn_mask and is_causal are set.
auto* is_causal_attr = ctx.getAttribute("is_causal");
int64_t is_causal = (is_causal_attr != nullptr) ? is_causal_attr->i() : 0;
float neg_inf = -std::numeric_limits<float>::infinity();
builder.Add("TempMask = ConstantOfShape(AttnBiasShape)", "value", mktensor(1))
.Add("TempMaskTri = Trilu <upper = 0> (TempMask, Zero1D)")
.Const1D("FloatInf", static_cast<float>('-inf'))
.Const1D("FloatInf", neg_inf)
.Add("CasualAttnBias = Where(TempMaskTri, AttnBiasZeros, FloatInf)")
.Const("IsCasual", is_causal)
.Add("AttnBias = Where(IsCausal, CasualAttnBias, AttnBiasZeros)")
Expand All @@ -3184,6 +3186,7 @@ ONNX_OPERATOR_SET_SCHEMA(
.Add("GQACond1 = Not(NGQACond1)")
.Add("DivNumHeads = Div(QNumHeads, KVNumHeads)")
.Add("IDivNumHeads = Cast(DivNumHeads)", "to", int_type)
.Add("RemainderNumHeads = Mod(QNumHeads, KVNumHeads)")
.Add("GQACond2 = Equal(RemainderNumHeads, Zero1D)")
.Add("GQACond = And(GQACond1, GQACond2)")
.Add("InterleaveShape = Concat <axis = 0> (One1D, IDivNumHeads, One1D, One1D)")
Expand Down

0 comments on commit 0799e71

Please sign in to comment.