Skip to content

Commit

Permalink
Support chained assignment statements, e.g. a = b = c.
Browse files Browse the repository at this point in the history
We know when we have begun a chained assignment when we process a DUP_TOP with non-null on the stack. Push a NODE_CHAINSTORE onto the stack when this happens, and keep it 'floating' on top of the stack for all STORE_X operations until the stack is empty.
To support versions of Python <= 2.5 which use DUP_TOP in more places, I modified ROT_TWO, ROT_THREE and ROT_FOUR to get rid of NODE_CHAINSTORE on the stack if it is present.
  • Loading branch information
Aralox committed Oct 23, 2020
1 parent 1db8d28 commit 7a89b72
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 14 deletions.
18 changes: 17 additions & 1 deletion ASTNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class ASTNode {
NODE_CONVERT, NODE_KEYWORD, NODE_RAISE, NODE_EXEC, NODE_BLOCK,
NODE_COMPREHENSION, NODE_LOADBUILDCLASS, NODE_AWAITABLE,
NODE_FORMATTEDVALUE, NODE_JOINEDSTR, NODE_CONST_MAP,
NODE_ANNOTATED_VAR,
NODE_ANNOTATED_VAR, NODE_CHAINSTORE,

// Empty node types
NODE_LOCALS,
Expand Down Expand Up @@ -71,11 +71,27 @@ class ASTNodeList : public ASTNode {
void removeLast();
void append(PycRef<ASTNode> node) { m_nodes.emplace_back(std::move(node)); }

protected:
ASTNodeList(list_t nodes, ASTNode::Type type)
: ASTNode(type), m_nodes(std::move(nodes)) { }

private:
list_t m_nodes;
};


class ASTChainStore : public ASTNodeList {
public:
ASTChainStore(list_t nodes, PycRef<ASTNode> src)
: ASTNodeList(nodes, NODE_CHAINSTORE), m_src(std::move(src)) { }

PycRef<ASTNode> src() const { return m_src; }

private:
PycRef<ASTNode> m_src;
};


class ASTObject : public ASTNode {
public:
ASTObject(PycRef<PycObject> obj)
Expand Down
110 changes: 100 additions & 10 deletions ASTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// NOTE: Nested f-strings not supported.
#define F_STRING_QUOTE "'''"

static void append_to_chain_store(PycRef<ASTNode> chainStore, PycRef<ASTNode> item, FastStack& stack, PycRef<ASTBlock> curblock);

/* Use this to determine if an error occurred (and therefore, if we should
* avoid cleaning the output tree) */
static bool cleanBuild;
Expand Down Expand Up @@ -288,6 +290,9 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
}
stack.push(map);
} else {
if (stack.top().type() == ASTNode::NODE_CHAINSTORE) {
stack.pop();
}
stack.push(new ASTMap());
}
break;
Expand Down Expand Up @@ -679,7 +684,22 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
}
break;
case Pyc::DUP_TOP:
stack.push(stack.top());
{
if (stack.top().type() != PycObject::TYPE_NULL) {
if (stack.top().type() == ASTNode::NODE_CHAINSTORE) {
auto chainstore = stack.top();
stack.pop();
stack.push(stack.top());
stack.push(chainstore);
} else {
stack.push(stack.top());
ASTNodeList::list_t targets;
stack.push(new ASTChainStore(targets, stack.top()));
}
} else {
stack.push(stack.top());
}
}
break;
case Pyc::DUP_TOP_TWO:
{
Expand Down Expand Up @@ -791,6 +811,9 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
break;
case Pyc::EXEC_STMT:
{
if (stack.top().type() == ASTNode::NODE_CHAINSTORE) {
stack.pop();
}
PycRef<ASTNode> loc = stack.top();
stack.pop();
PycRef<ASTNode> glob = stack.top();
Expand Down Expand Up @@ -1725,6 +1748,9 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
{
PycRef<ASTNode> one = stack.top();
stack.pop();
if (stack.top().type() == ASTNode::NODE_CHAINSTORE) {
stack.pop();
}
PycRef<ASTNode> two = stack.top();
stack.pop();

Expand All @@ -1738,6 +1764,9 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
stack.pop();
PycRef<ASTNode> two = stack.top();
stack.pop();
if (stack.top().type() == ASTNode::NODE_CHAINSTORE) {
stack.pop();
}
PycRef<ASTNode> three = stack.top();
stack.pop();
stack.push(one);
Expand All @@ -1753,6 +1782,9 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
stack.pop();
PycRef<ASTNode> three = stack.top();
stack.pop();
if (stack.top().type() == ASTNode::NODE_CHAINSTORE) {
stack.pop();
}
PycRef<ASTNode> four = stack.top();
stack.pop();
stack.push(one);
Expand Down Expand Up @@ -1889,17 +1921,23 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
stack.pop();
PycRef<ASTNode> seq = stack.top();
stack.pop();

curblock->append(new ASTStore(seq, tup));
if (seq.type() == ASTNode::NODE_CHAINSTORE) {
append_to_chain_store(seq, tup, stack, curblock);
} else {
curblock->append(new ASTStore(seq, tup));
}
}
} else {
PycRef<ASTNode> name = stack.top();
stack.pop();
PycRef<ASTNode> value = stack.top();
stack.pop();
PycRef<ASTNode> attr = new ASTBinary(name, new ASTName(code->getName(operand)), ASTBinary::BIN_ATTR);

curblock->append(new ASTStore(value, attr));
if (value.type() == ASTNode::NODE_CHAINSTORE) {
append_to_chain_store(value, attr, stack, curblock);
} else {
curblock->append(new ASTStore(value, attr));
}
}
}
break;
Expand All @@ -1919,13 +1957,22 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
PycRef<ASTNode> seq = stack.top();
stack.pop();

curblock->append(new ASTStore(seq, tup));
if (seq.type() == ASTNode::NODE_CHAINSTORE) {
append_to_chain_store(seq, tup, stack, curblock);
} else {
curblock->append(new ASTStore(seq, tup));
}
}
} else {
PycRef<ASTNode> value = stack.top();
stack.pop();
PycRef<ASTNode> name = new ASTName(code->getCellVar(operand));
curblock->append(new ASTStore(value, name));

if (value.type() == ASTNode::NODE_CHAINSTORE) {
append_to_chain_store(value, name, stack, curblock);
} else {
curblock->append(new ASTStore(value, name));
}
}
}
break;
Expand Down Expand Up @@ -1956,6 +2003,8 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
if (tuple != NULL)
tuple->setRequireParens(false);
curblock.cast<ASTIterBlock>()->setIndex(tup);
} else if (seq.type() == ASTNode::NODE_CHAINSTORE) {
append_to_chain_store(seq, tup, stack, curblock);
} else {
curblock->append(new ASTStore(seq, tup));
}
Expand Down Expand Up @@ -1983,6 +2032,8 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
&& !curblock->inited()) {
curblock.cast<ASTWithBlock>()->setExpr(value);
curblock.cast<ASTWithBlock>()->setVar(name);
} else if (value.type() == ASTNode::NODE_CHAINSTORE) {
append_to_chain_store(value, name, stack, curblock);
} else {
curblock->append(new ASTStore(value, name));
}
Expand Down Expand Up @@ -2011,14 +2062,20 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
if (tuple != NULL)
tuple->setRequireParens(false);
curblock.cast<ASTIterBlock>()->setIndex(tup);
} else if (seq.type() == ASTNode::NODE_CHAINSTORE) {
append_to_chain_store(seq, tup, stack, curblock);
} else {
curblock->append(new ASTStore(seq, tup));
}
}
} else {
PycRef<ASTNode> value = stack.top();
stack.pop();
curblock->append(new ASTStore(value, name));
if (value.type() == ASTNode::NODE_CHAINSTORE) {
append_to_chain_store(value, name, stack, curblock);
} else {
curblock->append(new ASTStore(value, name));
}
}

/* Mark the global as used */
Expand Down Expand Up @@ -2047,6 +2104,8 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
if (tuple != NULL)
tuple->setRequireParens(false);
curblock.cast<ASTIterBlock>()->setIndex(tup);
} else if (seq.type() == ASTNode::NODE_CHAINSTORE) {
append_to_chain_store(seq, tup, stack, curblock);
} else {
curblock->append(new ASTStore(seq, tup));
}
Expand Down Expand Up @@ -2080,6 +2139,8 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
&& !curblock->inited()) {
curblock.cast<ASTWithBlock>()->setExpr(value);
curblock.cast<ASTWithBlock>()->setVar(name);
} else if (value.type() == ASTNode::NODE_CHAINSTORE) {
append_to_chain_store(value, name, stack, curblock);
} else {
curblock->append(new ASTStore(value, name));

Expand Down Expand Up @@ -2157,8 +2218,11 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
stack.pop();
PycRef<ASTNode> seq = stack.top();
stack.pop();

curblock->append(new ASTStore(seq, tup));
if (seq.type() == ASTNode::NODE_CHAINSTORE) {
append_to_chain_store(seq, tup, stack, curblock);
} else {
curblock->append(new ASTStore(seq, tup));
}
}
} else {
PycRef<ASTNode> subscr = stack.top();
Expand Down Expand Up @@ -2189,6 +2253,8 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
} else {
if (dest.type() == ASTNode::NODE_MAP) {
dest.cast<ASTMap>()->add(subscr, src);
} else if (src.type() == ASTNode::NODE_CHAINSTORE) {
append_to_chain_store(src, new ASTSubscr(dest, subscr), stack, curblock);
} else {
curblock->append(new ASTStore(src, new ASTSubscr(dest, subscr)));
}
Expand Down Expand Up @@ -2255,6 +2321,10 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
&& !curblock->inited()) {
tup->setRequireParens(true);
curblock.cast<ASTIterBlock>()->setIndex(tup);
} else if (stack.top().type() == ASTNode::NODE_CHAINSTORE) {
auto chainStore = stack.top();
stack.pop();
append_to_chain_store(chainStore, tup, stack, curblock);
} else {
curblock->append(new ASTStore(stack.top(), tup));
stack.pop();
Expand Down Expand Up @@ -2319,6 +2389,17 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
return new ASTNodeList(defblock->nodes());
}

static void append_to_chain_store(PycRef<ASTNode> chainStore, PycRef<ASTNode> item, FastStack& stack, PycRef<ASTBlock> curblock)
{
stack.pop(); // ignore identical source object.
chainStore.cast<ASTChainStore>()->append(item);
if (stack.top().type() == PycObject::TYPE_NULL) {
curblock->append(chainStore);
} else {
stack.push(chainStore);
}
}

static int cmp_prec(PycRef<ASTNode> parent, PycRef<ASTNode> child)
{
/* Determine whether the parent has higher precedence than therefore
Expand Down Expand Up @@ -3028,6 +3109,15 @@ void print_src(PycRef<ASTNode> node, PycModule* mod)
}
}
break;
case ASTNode::NODE_CHAINSTORE:
{
for (auto& dest : node.cast<ASTChainStore>()->nodes()) {
print_src(dest, mod);
fputs(" = ", pyc_output);
}
print_src(node.cast<ASTChainStore>()->src(), mod);
}
break;
case ASTNode::NODE_SUBSCR:
{
print_src(node.cast<ASTSubscr>()->name(), mod);
Expand Down
Binary file added tests/compiled/chain_assignment.2.7.pyc
Binary file not shown.
Binary file added tests/compiled/chain_assignment.3.7.pyc
Binary file not shown.
39 changes: 39 additions & 0 deletions tests/input/chain_assignment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
a = [y, z] = x = (k1, k2, k3) = [] = c = myfunc(x) + 3

x = y = g = {keyA: X}

global store_global
Gx = Gy = Gz = Gq1
Gx = [Gy, Gz] = Gq2
a = b = store_global = c

def func_with_global():
global Gx, Gy, Gz, Gq
Gx = Gy = Gz = Gq

y = store_subscr[0] = x
a[0] = b[x] = c[3] = D[4]
a[0] = (b[x], c[3]) = D[4]
a[0] = Q = [b[x], c[3]] = F = D[4]
q = v = arr[a:b:c] = x

class store_attr1:
def __init__(self, a,b,c):
self.a = self.b = self.c = x
self.d = y

class store_attr2:
def __init__(self, a,b,c): self.a = (self.b, self.c) = x

a.b = c.d = e.f + g.h

def store_deref():
a = I
a = b = c = R1
a = (b, c) = R2
def store_fast():
x = a
y = b
z = c
p = q = r = s
p = [q, r] = s
43 changes: 43 additions & 0 deletions tests/tokenized/chain_assignment.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
a = ( y , z ) = x = ( k1 , k2 , k3 ) = ( ) = c = myfunc ( x ) + 3 <EOL>
x = y = g = { keyA : X } <EOL>
Gx = Gy = Gz = Gq1 <EOL>
Gx = ( Gy , Gz ) = Gq2 <EOL>
a = b = store_global = c <EOL>
def func_with_global ( ) : <EOL>
<INDENT>
global Gx , Gy , Gz <EOL>
Gx = Gy = Gz = Gq <EOL>
<OUTDENT>
y = store_subscr [ 0 ] = x <EOL>
a [ 0 ] = b [ x ] = c [ 3 ] = D [ 4 ] <EOL>
a [ 0 ] = ( b [ x ] , c [ 3 ] ) = D [ 4 ] <EOL>
a [ 0 ] = Q = ( b [ x ] , c [ 3 ] ) = F = D [ 4 ] <EOL>
q = v = arr [ a : b : c ] = x <EOL>
class store_attr1 : <EOL>
<INDENT>
def __init__ ( self , a , b , c ) : <EOL>
<INDENT>
self . a = self . b = self . c = x <EOL>
self . d = y <EOL>
<OUTDENT>
<OUTDENT>
class store_attr2 : <EOL>
<INDENT>
def __init__ ( self , a , b , c ) : <EOL>
<INDENT>
self . a = ( self . b , self . c ) = x <EOL>
<OUTDENT>
<OUTDENT>
a . b = c . d = e . f + g . h <EOL>
def store_deref ( ) : <EOL>
<INDENT>
a = I <EOL>
a = b = c = R1 <EOL>
a = ( b , c ) = R2 <EOL>
def store_fast ( ) : <EOL>
<INDENT>
x = a <EOL>
y = b <EOL>
z = c <EOL>
p = q = r = s <EOL>
p = ( q , r ) = s <EOL>
4 changes: 1 addition & 3 deletions tests/tokenized/f-string.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
var1 = 'x' <EOL>
var2 = 'y' <EOL>
x = 1.23456 <EOL>
s1 = 1.23456 <EOL>
var3 = 1.23456 <EOL>
x = s1 = var3 = 1.23456 <EOL>
a = 15 <EOL>
some_dict = { } <EOL>
some_dict [ 2 ] = 3 <EOL>
Expand Down

0 comments on commit 7a89b72

Please sign in to comment.