2
2
3
3
from __future__ import print_function
4
4
5
- import argparse
6
5
import ast
7
6
import sys
8
-
9
-
10
- version_info = (0 , 1 , 2 )
7
+ import fileinput
8
+ import json
9
+ version_info = (0 , 1 , 3 )
11
10
__version__ = '.' .join (map (str , version_info ))
12
11
13
12
@@ -37,6 +36,9 @@ def __init__(self, reason, node, filename):
37
36
self .filename = filename
38
37
self .node = node
39
38
39
+ def toDict (self ):
40
+ return {'file' : self .filename , 'line' : self .lineno , 'message' : self .reason }
41
+
40
42
def __str__ (self ):
41
43
return "%s:%d\t %s" % (self .filename , self .lineno , self .reason )
42
44
@@ -73,14 +75,12 @@ def check_execute(self, node):
73
75
if node .func .attr == 'format' :
74
76
return IllegalLine ('str.format called on SQL query' , node , self .filename )
75
77
elif isinstance (node , ast .Name ):
76
- # now we need to figure out where that query is assigned. blargh.
77
78
assignment = find_assignment_in_context (node .id , node )
78
79
if assignment is not None :
79
80
return self .check_execute (assignment .value )
80
81
81
82
def visit_Call (self , node ):
82
83
function_name = stringify (node .func )
83
- # catch and check aliases of session.execute and cursor.execute
84
84
if function_name .lower ().endswith ('.execute' ):
85
85
try :
86
86
node .args [0 ].parent = node
@@ -94,13 +94,11 @@ def visit_Call(self, node):
94
94
self .generic_visit (node )
95
95
96
96
def visit (self , node ):
97
- """Visit a node."""
98
97
method = 'visit_' + node .__class__ .__name__
99
98
visitor = getattr (self , method , self .generic_visit )
100
99
return visitor (node )
101
100
102
101
def generic_visit (self , node ):
103
- """Called if no explicit visitor function exists for a node."""
104
102
for field , value in ast .iter_fields (node ):
105
103
if isinstance (value , list ):
106
104
for item in value :
@@ -112,34 +110,56 @@ def generic_visit(self, node):
112
110
self .visit (value )
113
111
114
112
115
- def check (filename ):
113
+ def check (filename , args ):
116
114
c = Checker (filename = filename )
117
- with open (filename , 'r' ) as fobj :
118
- try :
119
- parsed = ast .parse (fobj .read (), filename )
120
- c .visit (parsed )
121
- except Exception :
122
- raise
115
+ if filename == '-' :
116
+ fobj = sys .stdin
117
+ else :
118
+ fobj = open (filename , 'r' )
119
+
120
+ try :
121
+ parsed = ast .parse (fobj .read (), filename )
122
+ c .visit (parsed )
123
+ except Exception :
124
+ raise
123
125
return c .errors
124
126
125
127
126
- def main ():
128
+ def create_parser ():
129
+ import argparse
127
130
parser = argparse .ArgumentParser (
128
- description = 'Look for patterns in python source files that might indicate SQL injection vulnerabilities' ,
129
- epilog = 'Exit status is 0 if all files are okay, 1 if any files have an error. Errors are printed to stdout '
131
+ description = 'Look for patterns in python source files that might indicate SQL injection or other vulnerabilities' ,
132
+ epilog = 'Exit status is 0 if all files are okay, 1 if any files have an error. Found vulnerabilities are printed to standard out '
130
133
)
131
- parser .add_argument ('--version' , action = 'version' , version = '%(prog)s ' + __version__ )
132
- parser .add_argument ('files' , nargs = '+' , help = 'Files to check' )
134
+ parser .add_argument ('-v' , '--version' , action = 'version' , version = '%(prog)s ' + __version__ )
135
+ parser .add_argument ('files' , nargs = '*' , help = 'files to check or \' -\' for standard in' )
136
+ parser .add_argument ('-j' , '--json' , action = 'store_true' , help = 'print output in JSON' )
137
+ parser .add_argument ('-s' , '--stdin' , action = 'store_true' , help = 'read from standard in, passed files are ignored' )
138
+ parser .add_argument ('-q' , '--quiet' , action = 'store_true' , help = 'do not print error statistics' )
139
+
140
+ return parser
141
+
142
+
143
+ def main ():
144
+ parser = create_parser ()
133
145
args = parser .parse_args ()
134
146
147
+ if not (args .files or args .stdin ):
148
+ parser .error ('incorrect number of arguments' )
149
+ if args .stdin :
150
+ args .files = ['-' ]
151
+
135
152
errors = []
136
153
for fname in args .files :
137
- these_errors = check (fname )
138
- if these_errors :
139
- print ('\n ' .join (str (e ) for e in these_errors ))
140
- errors .extend (these_errors )
154
+ errors .extend (check (fname , args ))
141
155
if errors :
142
- print ('%d total errors' % len (errors ), file = sys .stderr )
156
+ if args .json :
157
+ print (json .dumps (map (lambda x : x .toDict (), errors ),
158
+ indent = 2 , sort_keys = True ))
159
+ else :
160
+ print ('\n ' .join (str (e ) for e in errors ))
161
+ if not args .quiet :
162
+ print ('Total errors: %d' % len (errors ), file = sys .stderr )
143
163
return 1
144
164
else :
145
165
return 0
0 commit comments