-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
172 lines (147 loc) · 4.9 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
'''
mokuton.py
2.24.17
Generates abstract syntax trees from function source code and then inserts them, along with the
input types and output types, back into MongoDB.
Dependencies: javalang, pymongo
Refer to https://docs.mongodb.com/getting-started/python/client/ for more info on pymongo
Note:
- All strings are returned as unicode
- The parser expects a syntactically correct class, so a class template (HelloWorld in this example) is needed to place functions into.
'''
import javalang
import collections
from ast import nodes
from ast import nodeVect
from collections import Counter
# Convert string to numeric
def num(s):
    """Convert *s* to a number when possible.

    ``int`` is attempted first, then ``float``.  When both conversions
    raise ``ValueError`` the argument is handed back unchanged.  Any other
    exception raised by the conversions propagates to the caller.
    """
    for convert in (int, float):
        try:
            return convert(s)
        except ValueError:
            # Not parseable by this converter; fall through to the next one.
            continue
    return s
# Get string of type of literal
def getLiteral(vals):
    """Return the Python type name inferred for the first string in *vals*.

    Non-string entries are skipped.  The first string found is passed
    through ``num`` and the ``__name__`` of the resulting value's type is
    returned.  When that name is not found in ``nodeVect`` the module-level
    ``malformed`` flag is set to ``True`` before returning.

    Returns ``None`` when *vals* contains no string at all.
    """
    global malformed
    for candidate in vals:
        if not isinstance(candidate, basestring):
            continue
        type_name = type(num(candidate)).__name__
        if type_name.strip() not in nodeVect:
            malformed = True
        return type_name
    return None
# Recursively construct AST
def generateAST(tree):
    """Flatten *tree* into a parenthesised list of token strings.

    ``str(tree)`` is used as the node label.  A label that is not in
    ``nodes`` yields an empty list.  A ``'Literal'`` node becomes
    ``['(', <str of getLiteral(tree.children)>, ')']``.  Any other
    recognised node becomes ``['(', label, ...child tokens..., ')']``,
    where child tokens come from recursing into each child (or each element
    of a list child) whose ``str()`` is in ``nodes``.
    """
    label = str(tree)
    tokens = []
    if label not in nodes:
        return tokens

    tokens.append('(')
    if label == 'Literal':
        tokens.append(str(getLiteral(tree.children)))
        tokens.append(')')
        return tokens

    tokens.append(label)
    try:
        for child in tree.children:
            # A child that is exactly a list is searched one level deep;
            # anything else is treated as a single candidate.
            candidates = child if type(child) is list else [child]
            for candidate in candidates:
                if str(candidate) in nodes:
                    tokens.extend(generateAST(candidate))
    except AttributeError:
        # An AttributeError anywhere in the walk (including in a recursive
        # call) ends the walk for this node; tokens gathered so far are kept.
        pass
    tokens.append(')')
    return tokens
# Vectorize AST
def vectorize(tree):
    """Replace, in place, every entry of *tree* that is in ``nodeVect``.

    Entries found in ``nodeVect`` are swapped for ``nodeVect[entry]``; all
    other entries are left as they are.  The same list object is returned.
    """
    tree[:] = [nodeVect[token] if token in nodeVect else token for token in tree]
    return tree
# Inject function source into template to satisfy javalang module
def template(func):
    """Wrap the source text *func* in the body of a class named ``m``.

    Returns ``'public class m{' + func + '}'`` so that a bare function can
    be handed to a parser that requires a complete class.
    """
    return ''.join(('public class m{', func, '}'))
# Extracts the function AST out of the template
def extractFunc(tree):
    """Strip the outer wrapper from a parenthesised AST string.

    Everything up to and including the third ``'('`` is discarded, the
    final two characters of what remains are dropped, and a single ``'('``
    is put back in front.  When *tree* has fewer than three ``'('`` the
    result is just ``'('``.
    """
    pieces = tree.split('(', 3)
    # With three or more '(' the fourth piece is the untouched remainder.
    remainder = pieces[3] if len(pieces) > 3 else ''
    trimmed = remainder[:-2]
    return '(' + trimmed
# Return a 1x14 label vector
# The left seven indices represent the inputs and the right seven represent outputs
# Each position contains the number of each type
# e.g.
# public static int findFirst(int value, int idx)
# [2 0 0 0 0 0 0 1 0 0 0 0 0 0]
# The order is defined by the indx dict in the function below
def createLabel(intype, outtype):
    """Build the input/output type-count label vector.

    The vector has one slot per known type for the inputs followed by one
    slot per known type for the outputs (7 + 7 = 14 slots).  Each slot
    holds how many times that type occurs.

    Example: ``createLabel(['int', 'int'], ['int'])`` returns
    ``[2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]``.

    Args:
        intype: iterable of input type names.
        outtype: iterable of output type names.

    Returns:
        A list of 14 integers.

    Raises:
        KeyError: if a type name is not one of the seven keys of ``indx``.
    """
    indx = {'int':0, 'double':1, 'float':2, 'boolean':3, 'long':4, 'short':5, 'byte':6}
    width = len(indx)
    label = [0] * (2 * width)
    # Inputs fill the left half, outputs the right half.  Counting each side
    # once and reading (type, count) pairs from the same Counter avoids
    # pairing keys() and values() taken from two different objects.
    for offset, types in ((0, intype), (width, outtype)):
        for name, count in Counter(types).items():
            label[indx[name] + offset] = count
    return label
def checkParen(ast):
    """Report whether the parentheses in the token sequence *ast* balance.

    Prints the number of ``'('`` and ``')'`` tokens followed by ``pass`` or
    ``fail``.  The check fails when the counts differ or when a ``')'``
    appears before its matching ``'('``.

    Args:
        ast: iterable of tokens; only tokens equal to ``'('`` or ``')'``
            are counted.

    Returns:
        ``True`` when the parentheses are balanced, ``False`` otherwise.
    """
    print('checking parentheses:')
    leftCount = 0
    rightCount = 0
    ordered = True
    for x in ast:
        if x == '(':
            leftCount += 1
        elif x == ')':
            rightCount += 1
            # Equal totals are not enough: a close seen before its open
            # (e.g. ')(') is still unbalanced.
            if rightCount > leftCount:
                ordered = False
    # Joined with single spaces so the output is the same under Python 2
    # and Python 3.
    print(' '.join(['( count: ', str(leftCount), '\n ) count: ', str(rightCount)]))
    balanced = ordered and leftCount == rightCount
    if balanced:
        print('pass')
    else:
        print('fail')
    return balanced
if __name__ == "__main__":
    # Demo driver: parse one hard-coded Java snippet, build its flattened AST,
    # vectorize it and print the result.
    #
    # Example document generated by Pakkun. Each document is a Java class containing functions with the desired type, here, numeric->numeric.
    # {
    # "_id" : "3697816498",
    # "name" : "Bits.java",
    # "path" : "/home/ubuntu/research/data/neurosyntax_data/github/java/android/platform_dalvik/dx/src/com/android/dx/util/Bits.java",
    # "funcs" :
    # [
    # {
    # "id" : "4108688843",
    # "name" : "findFirst",
    # "header" : "public static int findFirst(int value, int idx)",
    # "intype" : [ "int", "int" ],
    # "outtype" : [ "int" ],
    # "source" : "public static int findFirst(int value, int idx) { value &= ~((1 << idx) - 1); int result = Integer.numberOfTrailingZeros(value); return (result == 32) ? -1 : result; }"
    # }
    # ]
    # }
    # well-formed
    # code = 'public class HelloWorld{public static float add(int a, int b){a+=5; return 3.14;}}'
    code = 'public class HelloWorld{public static float add(int a, int b){return 3.14+3.0;}}'
    # code = 'public class HelloWorld{public static int findFirst(int value, int idx) { value &= ~((1 << idx) - 1); int result = Integer.numberOfTrailingZeros(value); return (result == 32) ? -1 : result;}}'
    # malformed from unicode e.g. 0L and \"
    # code = 'public class HelloWorld{public static long verifyPositive(long value, String paramName) { if (value <= 0L) { throw new IllegalArgumentException(paramName + \" > 0 required but it was \" + value); } return value; }}'

    # Module-level flag; getLiteral() sets it to True when a literal's
    # inferred type name is not in nodeVect.
    malformed = False
    code = template(code)
    try:
        tree = javalang.parse.parse(code)
    except javalang.parser.JavaSyntaxError:
        # No tree was produced, so there is nothing further to process.
        print('JavaSyntaxError:\n' + code)
    else:
        sample = {}
        sample['ast'] = list(generateAST(tree))
        if not malformed:
            # Remove empty strings, and nodes and parentheses from dummy class
            sample['ast'] = [v for v in sample['ast'] if len(v) > 0][6:][:-3]
            # Vectorize a copy so sample['ast'] keeps its string tokens.
            sample['astvec'] = vectorize(sample['ast'][:])
            print(sample)
            checkParen(sample['ast'])
        else:
            print('malformed literals\n')