Skip to content

[flang] Revamp evaluate::CoarrayRef #136628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 12 additions & 26 deletions flang/include/flang/Evaluate/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -391,20 +391,17 @@ template <typename T>
bool IsArrayElement(const Expr<T> &expr, bool intoSubstring = true,
bool skipComponents = false) {
if (auto dataRef{ExtractDataRef(expr, intoSubstring)}) {
const DataRef *ref{&*dataRef};
if (skipComponents) {
while (const Component * component{std::get_if<Component>(&ref->u)}) {
ref = &component->base();
for (const DataRef *ref{&*dataRef}; ref;) {
if (const Component * component{std::get_if<Component>(&ref->u)}) {
ref = skipComponents ? &component->base() : nullptr;
} else if (const auto *coarrayRef{std::get_if<CoarrayRef>(&ref->u)}) {
ref = &coarrayRef->base();
} else {
return std::holds_alternative<ArrayRef>(ref->u);
}
}
if (const auto *coarrayRef{std::get_if<CoarrayRef>(&ref->u)}) {
return !coarrayRef->subscript().empty();
} else {
return std::holds_alternative<ArrayRef>(ref->u);
}
} else {
return false;
}
return false;
}

template <typename A>
Expand All @@ -418,9 +415,6 @@ std::optional<NamedEntity> ExtractNamedEntity(const A &x) {
[](Component &&component) -> std::optional<NamedEntity> {
return NamedEntity{std::move(component)};
},
[](CoarrayRef &&co) -> std::optional<NamedEntity> {
return co.GetBase();
},
[](auto &&) { return std::optional<NamedEntity>{}; },
},
std::move(dataRef->u));
Expand Down Expand Up @@ -528,22 +522,14 @@ const Symbol *UnwrapWholeSymbolOrComponentDataRef(const A &x) {
// If an expression is a whole symbol or a whole component designator,
// potentially followed by an image selector, extract and return that symbol,
// else null.
const Symbol *UnwrapWholeSymbolOrComponentOrCoarrayRef(const DataRef &);
template <typename A>
const Symbol *UnwrapWholeSymbolOrComponentOrCoarrayRef(const A &x) {
if (auto dataRef{ExtractDataRef(x)}) {
if (const SymbolRef * p{std::get_if<SymbolRef>(&dataRef->u)}) {
return &p->get();
} else if (const Component * c{std::get_if<Component>(&dataRef->u)}) {
if (c->base().Rank() == 0) {
return &c->GetLastSymbol();
}
} else if (const CoarrayRef * c{std::get_if<CoarrayRef>(&dataRef->u)}) {
if (c->subscript().empty()) {
return &c->GetLastSymbol();
}
}
return UnwrapWholeSymbolOrComponentOrCoarrayRef(*dataRef);
} else {
return nullptr;
}
return nullptr;
}

// GetFirstSymbol(A%B%C[I]%D) -> A
Expand Down
3 changes: 1 addition & 2 deletions flang/include/flang/Evaluate/traverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ class Traverse {
return Combine(x.base(), x.subscript());
}
Result operator()(const CoarrayRef &x) const {
return Combine(
x.base(), x.subscript(), x.cosubscript(), x.stat(), x.team());
return Combine(x.base(), x.cosubscript(), x.stat(), x.team());
}
Result operator()(const DataRef &x) const { return visitor_(x.u); }
Result operator()(const Substring &x) const {
Expand Down
41 changes: 13 additions & 28 deletions flang/include/flang/Evaluate/variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ class Component {

// A NamedEntity is either a whole Symbol or a component in an instance
// of a derived type. It may be a descriptor.
// TODO: this is basically a symbol with an optional DataRef base;
// could be used to replace Component.
class NamedEntity {
public:
CLASS_BOILERPLATE(NamedEntity)
Expand Down Expand Up @@ -239,28 +237,16 @@ class ArrayRef {
std::vector<Subscript> subscript_;
};

// R914 coindexed-named-object
// R924 image-selector, R926 image-selector-spec.
// C825 severely limits the usage of derived types with coarray ultimate
// components: they can't be pointers, allocatables, arrays, coarrays, or
// function results. They can be components of other derived types.
// Although the F'2018 Standard never prohibits multiple image-selectors
// per se in the same data-ref or designator, nor the presence of an
// image-selector after a part-ref with rank, the constraints on the
// derived types that would have be involved make it impossible to declare
// an object that could be referenced in these ways (esp. C748 & C825).
// C930 precludes having both TEAM= and TEAM_NUMBER=.
// TODO C931 prohibits the use of a coindexed object as a stat-variable.
// A coindexed data-ref. The base is represented as a general
// DataRef, but the base may not contain a CoarrayRef and may
// have rank > 0 only in an uppermost ArrayRef.
class CoarrayRef {
public:
CLASS_BOILERPLATE(CoarrayRef)
CoarrayRef(SymbolVector &&, std::vector<Subscript> &&,
std::vector<Expr<SubscriptInteger>> &&);
CoarrayRef(DataRef &&, std::vector<Expr<SubscriptInteger>> &&);

const SymbolVector &base() const { return base_; }
SymbolVector &base() { return base_; }
const std::vector<Subscript> &subscript() const { return subscript_; }
std::vector<Subscript> &subscript() { return subscript_; }
const DataRef &base() const { return base_.value(); }
DataRef &base() { return base_.value(); }
const std::vector<Expr<SubscriptInteger>> &cosubscript() const {
return cosubscript_;
}
Expand All @@ -270,25 +256,24 @@ class CoarrayRef {
// (i.e., Designator or pointer-valued FunctionRef).
std::optional<Expr<SomeInteger>> stat() const;
CoarrayRef &set_stat(Expr<SomeInteger> &&);
std::optional<Expr<SomeInteger>> team() const;
bool teamIsTeamNumber() const { return teamIsTeamNumber_; }
CoarrayRef &set_team(Expr<SomeInteger> &&, bool isTeamNumber = false);
// When team() is Expr<SomeInteger>, it's TEAM_NUMBER=; otherwise,
// it's TEAM=.
std::optional<Expr<SomeType>> team() const;
CoarrayRef &set_team(Expr<SomeType> &&);

int Rank() const;
int Corank() const { return 0; }
const Symbol &GetFirstSymbol() const;
const Symbol &GetLastSymbol() const;
NamedEntity GetBase() const;
std::optional<Expr<SubscriptInteger>> LEN() const;
bool operator==(const CoarrayRef &) const;
llvm::raw_ostream &AsFortran(llvm::raw_ostream &) const;

private:
SymbolVector base_;
std::vector<Subscript> subscript_;
common::CopyableIndirection<DataRef> base_;
std::vector<Expr<SubscriptInteger>> cosubscript_;
std::optional<common::CopyableIndirection<Expr<SomeInteger>>> stat_, team_;
bool teamIsTeamNumber_{false}; // false: TEAM=, true: TEAM_NUMBER=
std::optional<common::CopyableIndirection<Expr<SomeInteger>>> stat_;
std::optional<common::CopyableIndirection<Expr<SomeType>>> team_;
};

// R911 data-ref is defined syntactically as a series of part-refs, which
Expand Down
5 changes: 1 addition & 4 deletions flang/lib/Evaluate/check-expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -946,10 +946,7 @@ class IsContiguousHelper
return std::nullopt;
}
}
Result operator()(const CoarrayRef &x) const {
int rank{0};
return CheckSubscripts(x.subscript(), rank).has_value();
}
Result operator()(const CoarrayRef &x) const { return (*this)(x.base()); }
Result operator()(const Component &x) const {
if (x.base().Rank() == 0) {
return (*this)(x.GetLastSymbol());
Expand Down
13 changes: 4 additions & 9 deletions flang/lib/Evaluate/fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,22 +162,17 @@ ArrayRef FoldOperation(FoldingContext &context, ArrayRef &&arrayRef) {
}

CoarrayRef FoldOperation(FoldingContext &context, CoarrayRef &&coarrayRef) {
std::vector<Subscript> subscript;
for (Subscript x : coarrayRef.subscript()) {
subscript.emplace_back(FoldOperation(context, std::move(x)));
}
DataRef base{FoldOperation(context, std::move(coarrayRef.base()))};
std::vector<Expr<SubscriptInteger>> cosubscript;
for (Expr<SubscriptInteger> x : coarrayRef.cosubscript()) {
cosubscript.emplace_back(Fold(context, std::move(x)));
}
CoarrayRef folded{std::move(coarrayRef.base()), std::move(subscript),
std::move(cosubscript)};
CoarrayRef folded{std::move(base), std::move(cosubscript)};
if (std::optional<Expr<SomeInteger>> stat{coarrayRef.stat()}) {
folded.set_stat(Fold(context, std::move(*stat)));
}
if (std::optional<Expr<SomeInteger>> team{coarrayRef.team()}) {
folded.set_team(
Fold(context, std::move(*team)), coarrayRef.teamIsTeamNumber());
if (std::optional<Expr<SomeType>> team{coarrayRef.team()}) {
folded.set_team(Fold(context, std::move(*team)));
}
return folded;
}
Expand Down
26 changes: 6 additions & 20 deletions flang/lib/Evaluate/formatting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -723,24 +723,8 @@ llvm::raw_ostream &ArrayRef::AsFortran(llvm::raw_ostream &o) const {
}

llvm::raw_ostream &CoarrayRef::AsFortran(llvm::raw_ostream &o) const {
bool first{true};
for (const Symbol &part : base_) {
if (first) {
first = false;
} else {
o << '%';
}
EmitVar(o, part);
}
char separator{'('};
for (const auto &sscript : subscript_) {
EmitVar(o << separator, sscript);
separator = ',';
}
if (separator == ',') {
o << ')';
}
separator = '[';
base().AsFortran(o);
char separator{'['};
for (const auto &css : cosubscript_) {
EmitVar(o << separator, css);
separator = ',';
Expand All @@ -750,8 +734,10 @@ llvm::raw_ostream &CoarrayRef::AsFortran(llvm::raw_ostream &o) const {
separator = ',';
}
if (team_) {
EmitVar(
o << separator, team_, teamIsTeamNumber_ ? "TEAM_NUMBER=" : "TEAM=");
EmitVar(o << separator, team_,
std::holds_alternative<Expr<SomeInteger>>(team_->value().u)
? "TEAM_NUMBER="
: "TEAM=");
}
return o << ']';
}
Expand Down
15 changes: 1 addition & 14 deletions flang/lib/Evaluate/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -891,20 +891,7 @@ auto GetShapeHelper::operator()(const ArrayRef &arrayRef) const -> Result {
}

auto GetShapeHelper::operator()(const CoarrayRef &coarrayRef) const -> Result {
NamedEntity base{coarrayRef.GetBase()};
if (coarrayRef.subscript().empty()) {
return (*this)(base);
} else {
Shape shape;
int dimension{0};
for (const Subscript &ss : coarrayRef.subscript()) {
if (ss.Rank() > 0) {
shape.emplace_back(GetExtent(ss, base, dimension));
}
++dimension;
}
return shape;
}
return (*this)(coarrayRef.base());
}

auto GetShapeHelper::operator()(const Substring &substring) const -> Result {
Expand Down
15 changes: 14 additions & 1 deletion flang/lib/Evaluate/tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,7 @@ auto GetSymbolVectorHelper::operator()(const ArrayRef &x) const -> Result {
return GetSymbolVector(x.base());
}
auto GetSymbolVectorHelper::operator()(const CoarrayRef &x) const -> Result {
return x.base();
return GetSymbolVector(x.base());
}

const Symbol *GetLastTarget(const SymbolVector &symbols) {
Expand Down Expand Up @@ -1320,6 +1320,19 @@ std::optional<parser::MessageFixedText> CheckProcCompatibility(bool isCall,
return msg;
}

const Symbol *UnwrapWholeSymbolOrComponentOrCoarrayRef(const DataRef &dataRef) {
if (const SymbolRef * p{std::get_if<SymbolRef>(&dataRef.u)}) {
return &p->get();
} else if (const Component * c{std::get_if<Component>(&dataRef.u)}) {
if (c->base().Rank() == 0) {
return &c->GetLastSymbol();
}
} else if (const CoarrayRef * c{std::get_if<CoarrayRef>(&dataRef.u)}) {
return UnwrapWholeSymbolOrComponentOrCoarrayRef(c->base());
}
return nullptr;
}

// GetLastPointerSymbol()
static const Symbol *GetLastPointerSymbol(const Symbol &symbol) {
return IsPointer(GetAssociationRoot(symbol)) ? &symbol : nullptr;
Expand Down
57 changes: 14 additions & 43 deletions flang/lib/Evaluate/variable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,9 @@ Triplet &Triplet::set_stride(Expr<SubscriptInteger> &&expr) {
return *this;
}

CoarrayRef::CoarrayRef(SymbolVector &&base, std::vector<Subscript> &&ss,
std::vector<Expr<SubscriptInteger>> &&css)
: base_{std::move(base)}, subscript_(std::move(ss)),
cosubscript_(std::move(css)) {
CHECK(!base_.empty());
CHECK(!cosubscript_.empty());
}
CoarrayRef::CoarrayRef(
DataRef &&base, std::vector<Expr<SubscriptInteger>> &&css)
: base_{std::move(base)}, cosubscript_(std::move(css)) {}

std::optional<Expr<SomeInteger>> CoarrayRef::stat() const {
if (stat_) {
Expand All @@ -85,7 +81,7 @@ std::optional<Expr<SomeInteger>> CoarrayRef::stat() const {
}
}

std::optional<Expr<SomeInteger>> CoarrayRef::team() const {
std::optional<Expr<SomeType>> CoarrayRef::team() const {
if (team_) {
return team_.value().value();
} else {
Expand All @@ -99,16 +95,18 @@ CoarrayRef &CoarrayRef::set_stat(Expr<SomeInteger> &&v) {
return *this;
}

CoarrayRef &CoarrayRef::set_team(Expr<SomeInteger> &&v, bool isTeamNumber) {
CHECK(IsVariable(v));
CoarrayRef &CoarrayRef::set_team(Expr<SomeType> &&v) {
team_.emplace(std::move(v));
teamIsTeamNumber_ = isTeamNumber;
return *this;
}

const Symbol &CoarrayRef::GetFirstSymbol() const { return base_.front(); }
const Symbol &CoarrayRef::GetFirstSymbol() const {
return base().GetFirstSymbol();
}

const Symbol &CoarrayRef::GetLastSymbol() const { return base_.back(); }
const Symbol &CoarrayRef::GetLastSymbol() const {
return base().GetLastSymbol();
}

void Substring::SetBounds(std::optional<Expr<SubscriptInteger>> &lower,
std::optional<Expr<SubscriptInteger>> &upper) {
Expand Down Expand Up @@ -426,17 +424,7 @@ int ArrayRef::Rank() const {
}
}

int CoarrayRef::Rank() const {
if (!subscript_.empty()) {
int rank{0};
for (const auto &expr : subscript_) {
rank += expr.Rank();
}
return rank;
} else {
return base_.back()->Rank();
}
}
int CoarrayRef::Rank() const { return base().Rank(); }

int DataRef::Rank() const {
return common::visit(common::visitors{
Expand Down Expand Up @@ -671,22 +659,6 @@ std::optional<DynamicType> Designator<T>::GetType() const {
return std::nullopt;
}

static NamedEntity AsNamedEntity(const SymbolVector &x) {
CHECK(!x.empty());
NamedEntity result{x.front()};
int j{0};
for (const Symbol &symbol : x) {
if (j++ != 0) {
DataRef base{result.IsSymbol() ? DataRef{result.GetLastSymbol()}
: DataRef{result.GetComponent()}};
result = NamedEntity{Component{std::move(base), symbol}};
}
}
return result;
}

NamedEntity CoarrayRef::GetBase() const { return AsNamedEntity(base_); }

// Equality testing

// For the purposes of comparing type parameter expressions while
Expand Down Expand Up @@ -759,9 +731,8 @@ bool ArrayRef::operator==(const ArrayRef &that) const {
return base_ == that.base_ && subscript_ == that.subscript_;
}
bool CoarrayRef::operator==(const CoarrayRef &that) const {
return base_ == that.base_ && subscript_ == that.subscript_ &&
cosubscript_ == that.cosubscript_ && stat_ == that.stat_ &&
team_ == that.team_ && teamIsTeamNumber_ == that.teamIsTeamNumber_;
return base_ == that.base_ && cosubscript_ == that.cosubscript_ &&
stat_ == that.stat_ && team_ == that.team_;
}
bool DataRef::operator==(const DataRef &that) const {
return TestVariableEquality(*this, that);
Expand Down
Loading
Loading