Merge lp:~ben-hutchings/ensoft-sextant/csv-upload into lp:ensoft-sextant
- csv-upload
- Merge into whiteline
Status: | Merged | ||||
---|---|---|---|---|---|
Approved by: | Robert | ||||
Approved revision: | 62 | ||||
Merged at revision: | 30 | ||||
Proposed branch: | lp:~ben-hutchings/ensoft-sextant/csv-upload | ||||
Merge into: | lp:ensoft-sextant | ||||
Diff against target: |
3573 lines (+1945/-1015) 16 files modified
src/sextant/__main__.py (+54/-232) src/sextant/csvwriter.py (+152/-0) src/sextant/db_api.py (+587/-302) src/sextant/export.py (+3/-3) src/sextant/objdump_parser.py (+302/-262) src/sextant/query.py (+29/-31) src/sextant/sshmanager.py (+278/-0) src/sextant/test_all.sh (+4/-0) src/sextant/test_csvwriter.py (+89/-0) src/sextant/test_db_api.py (+68/-54) src/sextant/test_parser.py (+85/-0) src/sextant/test_resources/parser_test.c (+57/-0) src/sextant/test_resources/parser_test.dump (+44/-0) src/sextant/test_sshmanager.py (+72/-0) src/sextant/update_db.py (+96/-62) src/sextant/web/server.py (+25/-69) |
||||
To merge this branch: | bzr merge lp:~ben-hutchings/ensoft-sextant/csv-upload | ||||
Related bugs: |
|
Reviewer | Review Type | Date Requested | Status |
---|---|---|---|
Robert | Approve | ||
Review via email: mp+239356@code.launchpad.net |
Commit message
Programs are now uploaded in two steps: each program is first parsed into csv files, and the database is then populated from those files. This is _significantly_ faster for large programs.
Furthermore, the structure of the program nodes in the database has changed: previously they were unlabelled nodes with type 'program', whereas they now carry the 'program' label. (The database partitions on label, so this labelling keeps programs distinct from functions.) All queries in sextant have been updated to reflect this.
New module sshmanager handles the ssh connection to the database server.
New module csvwriter deals with the nuts and bolts of the csv files.
Description of the change
Programs are now uploaded in two steps: each program is first parsed into csv files, and the database is then populated from those files. This is _significantly_ faster for large programs.
Furthermore, the structure of the program nodes in the database has changed: previously they were unlabelled nodes with type 'program', whereas they now carry the 'program' label. (The database partitions on label, so this labelling keeps programs distinct from functions.) All queries in sextant have been updated to reflect this.
New module sshmanager handles the ssh connection to the database server.
New module csvwriter deals with the nuts and bolts of the csv files.
Robert (rjwills) : | # |
Preview Diff
1 | === modified file 'src/sextant/__main__.py' |
2 | --- src/sextant/__main__.py 2014-10-03 13:00:52 +0000 |
3 | +++ src/sextant/__main__.py 2014-10-23 12:33:12 +0000 |
4 | @@ -9,7 +9,6 @@ |
5 | |
6 | import io |
7 | import sys |
8 | -import random |
9 | import socket |
10 | import logging |
11 | import logging.config |
12 | @@ -28,10 +27,12 @@ |
13 | from . import db_api |
14 | from . import update_db |
15 | from . import environment |
16 | +from . import sshmanager |
17 | |
18 | config = environment.load_config() |
19 | |
20 | |
21 | + |
22 | def _displayable_url(args): |
23 | """ |
24 | Return the URL specified by the user for Sextant to look at. |
25 | @@ -56,7 +57,7 @@ |
26 | |
27 | # Beginning of functions which handle the actual invocation of Sextant |
28 | |
29 | -def _start_web(args): |
30 | +def _start_web(connection, args): |
31 | # Don't import at top level - makes twisted dependency semi-optional, |
32 | # allowing non-web functionality to work with Python 3. |
33 | if sys.version_info[0] == 2: |
34 | @@ -68,12 +69,12 @@ |
35 | logging.info("Serving site on port {}".format(args.port)) |
36 | |
37 | # server is .web.server, imported a couple of lines ago |
38 | - server.serve_site(input_database_url=args.remote_neo4j, port=args.port) |
39 | - |
40 | - |
41 | -def _audit(args): |
42 | + server.serve_site(connection, args.port) |
43 | + |
44 | + |
45 | +def _audit(connection, args): |
46 | try: |
47 | - audited = query.audit(args.remote_neo4j) |
48 | + audited = query.audit(connection) |
49 | except requests.exceptions.ConnectionError as e: |
50 | msg = 'Connection error to server {url}: {exception}' |
51 | logging.error(msg.format(url=_displayable_url(args), exception=e)) |
52 | @@ -87,8 +88,8 @@ |
53 | titles = ("Name", "#Func", "Uploader", "User-ID", "Upload Date") |
54 | colminlens = (len(entry) for entry in titles) |
55 | # maximum lengths to avoid one entry from throwing the whole table |
56 | - # date format is <YYYY:MM:DD HH:MM:SS.UUUUUU> = 26 characters |
57 | - COLMAXLENS = (25, 5, 25, 10, 26) |
58 | + # date format is <YYYY-MM-DD HH:MM:SS> = 19 characters |
59 | + COLMAXLENS = (25, 6, 25, 10, 19) |
60 | |
61 | # make a table of the strings of each data entry we will display |
62 | text = [map(str, (p.program_name, p.number_of_funcs, |
63 | @@ -120,7 +121,7 @@ |
64 | print('\n'.join(st.format(*pentry) for pentry in text)) |
65 | |
66 | |
67 | -def _add_program(args): |
68 | +def _add_program(connection, args): |
69 | try: |
70 | alternative_name = args.name_in_db[0] |
71 | except TypeError: |
72 | @@ -131,12 +132,11 @@ |
73 | # unsupplied |
74 | |
75 | try: |
76 | - update_db.upload_program(user_name=getpass.getuser(), |
77 | - file_path=args.input_file, |
78 | - db_url=args.remote_neo4j, |
79 | - alternative_name=alternative_name, |
80 | - not_object_file=not_object_file, |
81 | - display_url=_displayable_url(args)) |
82 | + update_db.upload_program(connection, |
83 | + getpass.getuser(), |
84 | + args.input_file, |
85 | + alternative_name, |
86 | + not_object_file) |
87 | except requests.exceptions.ConnectionError as e: |
88 | msg = 'Connection error to server {}: {}' |
89 | logging.error(msg.format(_displayable_url(args), e)) |
90 | @@ -147,41 +147,41 @@ |
91 | logging.error('Input file {} was not found.'.format(args.input_file[0])) |
92 | logging.error(e) |
93 | logging.debug(e, exc_info=True) |
94 | - except ValueError as e: |
95 | + except (ValueError, sshmanager.SSHConnectionError) as e: |
96 | logging.error(e) |
97 | |
98 | |
99 | -def _delete_program(namespace): |
100 | - update_db.delete_program(namespace.program_name, |
101 | - namespace.remote_neo4j) |
102 | - |
103 | - |
104 | -def _make_query(namespace): |
105 | +def _delete_program(connection, args): |
106 | + update_db.delete_program(connection, args.program_name) |
107 | + |
108 | + |
109 | +def _make_query(connection, args): |
110 | arg1 = None |
111 | arg2 = None |
112 | try: |
113 | - arg1 = namespace.funcs[0] |
114 | - arg2 = namespace.funcs[1] |
115 | + arg1 = args.funcs[0] |
116 | + arg2 = args.funcs[1] |
117 | except TypeError: |
118 | pass |
119 | except IndexError: |
120 | pass |
121 | |
122 | try: |
123 | - program_name = namespace.program[0] |
124 | + program_name = args.program[0] |
125 | except TypeError: |
126 | program_name = None |
127 | |
128 | try: |
129 | - suppress_common = namespace.suppress_common[0] |
130 | + suppress_common = args.suppress_common[0] |
131 | except TypeError: |
132 | suppress_common = False |
133 | |
134 | - query.query(remote_neo4j=namespace.remote_neo4j, |
135 | - display_neo4j=_displayable_url(namespace), |
136 | - input_query=namespace.query, |
137 | + query.query(remote_neo4j=args.remote_neo4j, |
138 | + display_neo4j=_displayable_url(args), |
139 | + input_query=args.query, |
140 | program_name=program_name, |
141 | - argument_1=arg1, argument_2=arg2, |
142 | + argument_1=arg1, |
143 | + argument_2=arg2, |
144 | suppress_common=suppress_common) |
145 | |
146 | # End of functions which invoke Sextant |
147 | @@ -197,8 +197,10 @@ |
148 | |
149 | """ |
150 | |
151 | - argumentparser = argparse.ArgumentParser(prog='sextant', usage='sextant', description="Invoke part of the SEXTANT program") |
152 | - subparsers = argumentparser.add_subparsers(title="subcommands") |
153 | + ap = argparse.ArgumentParser(prog='sextant', |
154 | + usage='sextant', |
155 | + description="Invoke part of the SEXTANT program") |
156 | + subparsers = ap.add_subparsers(title="subcommands") |
157 | |
158 | #set what will be defined as a "common function" |
159 | db_api.set_common_cutoff(config.common_cutoff) |
160 | @@ -257,10 +259,9 @@ |
161 | parsers[key].add_argument('--remote-neo4j', metavar="URL", |
162 | help="URL of neo4j server", type=str, |
163 | default=config.remote_neo4j) |
164 | - parsers[key].add_argument('--use-ssh-tunnel', metavar="BOOL", type=str, |
165 | - help="whether to SSH into the remote server," |
166 | - "True/False", |
167 | - default=str(config.use_ssh_tunnel)) |
168 | + parsers[key].add_argument('--no-ssh-tunnel', |
169 | + help='Disable ssh tunnelling. Prevents program upload.', |
170 | + action='store_true') |
171 | parsers[key].add_argument('--ssh-user', metavar="NAME", type=str, |
172 | help="username to use as remote SSH name", |
173 | default=str(config.ssh_user)) |
174 | @@ -273,207 +274,28 @@ |
175 | |
176 | # parse the arguments |
177 | |
178 | - return argumentparser.parse_args() |
179 | - |
180 | - |
181 | -def _start_tunnel(local_port, remote_host, remote_port, ssh_user=''): |
182 | - """ |
183 | - Creates an SSH port-forward. |
184 | - |
185 | - This will result in localhost:local_port appearing to be |
186 | - remote_host:remote_port. |
187 | - |
188 | - :param local_port: integer port number to open at localhost |
189 | - :param remote_host: string address of remote host (no port number) |
190 | - :param remote_port: port to 'open' on the remote host |
191 | - :param ssh_user: user to log in on the remote_host as |
192 | - |
193 | - """ |
194 | - |
195 | - if not (isinstance(local_port, int) and local_port > 0): |
196 | - raise ValueError( |
197 | - 'Local port {} must be a positive integer.'.format(local_port)) |
198 | - if not (isinstance(remote_port, int) and remote_port > 0): |
199 | - raise ValueError( |
200 | - 'Remote port {} must be a positive integer.'.format(remote_port)) |
201 | - |
202 | - logging.debug('Starting SSH tunnel...') |
203 | - |
204 | - # this cmd string will be .format()ed in a few lines' time |
205 | - cmd = ['ssh'] |
206 | - |
207 | - if ssh_user: |
208 | - # ssh -l {user} ... sets the remote login username |
209 | - cmd += ['-l', ssh_user] |
210 | - |
211 | - # -L localport:localhost:remoteport forwards the port |
212 | - # -M makes SSH able to accept slave connections |
213 | - # -S sets the location of a control socket (in this case, sextant-controller |
214 | - # with a unique identifier appended, just in case we run sextant twice |
215 | - # simultaneously), so we know how to close the port again |
216 | - # -f goes into background; -N does not execute a remote command; |
217 | - # -T says to remote host that we don't want a text shell. |
218 | - cmd += ['-M', |
219 | - '-S', 'sextantcontroller{tunnel_id}'.format(tunnel_id=local_port), |
220 | - '-fNT', |
221 | - '-L', '{0}:localhost:{1}'.format(local_port, remote_port), |
222 | - remote_host] |
223 | - |
224 | - logging.debug('Running {}'.format(' '.join(cmd))) |
225 | - |
226 | - exit_code = subprocess.call(cmd) |
227 | - if exit_code: |
228 | - raise OSError('SSH setup failed with error {}'.format(exit_code)) |
229 | - |
230 | - logging.debug('SSH tunnel created.') |
231 | - |
232 | - |
233 | -def _stop_tunnel(local_port, remote_host): |
234 | - """ |
235 | - Tear down an SSH port-forward which was previously set up with start_tunnel. |
236 | - |
237 | - We use local_port as an identifier. |
238 | - :param local_port: the port on localhost we are using as the entrypoint |
239 | - :param remote_host: remote host we tunnelled into |
240 | - |
241 | - """ |
242 | - |
243 | - logging.debug('Shutting down SSH tunnel...') |
244 | - |
245 | - # ssh -O sends a command to the slave specified in -S |
246 | - cmd = ['ssh', |
247 | - '-S', 'sextantcontroller{}'.format(local_port), |
248 | - '-O', 'exit', |
249 | - '-q', # for quiet |
250 | - remote_host] |
251 | - |
252 | - # SSH has a bug on some systems which causes it to ignore the -q flag |
253 | - # meaning it prints "Exit request sent." to stderr. |
254 | - # To avoid this, we grab stderr temporarily, and see if it's that string; |
255 | - # if it is, suppress it. |
256 | - pr = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
257 | - stdout, stderr = pr.communicate() |
258 | - if stderr.rstrip() != 'Exit request sent.': |
259 | - print(stderr, file=sys.stderr) |
260 | - if pr.returncode == 0: |
261 | - logging.debug('Shut down successfully.') |
262 | - else: |
263 | - logging.warning( |
264 | - 'SSH tunnel shutdown returned error code {}'.format(pr.returncode)) |
265 | - logging.warning(stderr) |
266 | - |
267 | - |
268 | -def _is_port_used(port): |
269 | - """ |
270 | - Checks with the OS to see whether a port is open. |
271 | - |
272 | - Beware: port is passed directly to the shell. Make sure it is an integer. |
273 | - We raise ValueError if it is not. |
274 | - :param port: integer port to check for openness |
275 | - :return: bool(port is in use) |
276 | - |
277 | - """ |
278 | - |
279 | - # we follow http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python |
280 | - if not (isinstance(port, int) and port > 0): |
281 | - raise ValueError('port {} must be a positive integer.'.format(port)) |
282 | - |
283 | - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
284 | - try: |
285 | - sock.bind(('127.0.0.1', port)) |
286 | - except socket.error as e: |
287 | - if e.errno == 98: # Address already in use |
288 | - return True |
289 | - raise |
290 | - |
291 | - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
292 | - |
293 | - return False # that is, the port is not used |
294 | - |
295 | - |
296 | -def _get_unused_port(): |
297 | - """ |
298 | - Returns a port number between 10000 and 50000 which is not currently open. |
299 | - |
300 | - """ |
301 | - |
302 | - keep_going = True |
303 | - while keep_going: |
304 | - portnum = random.randint(10000, 50000) |
305 | - keep_going = _is_port_used(portnum) |
306 | - return portnum |
307 | - |
308 | + return ap.parse_args() |
309 | |
310 | def _get_host_and_port(url): |
311 | - """Given a URL as http://host:port, returns (host, port).""" |
312 | - parsed = parse.urlparse(url) |
313 | - return (parsed.hostname, parsed.port) |
314 | - |
315 | - |
316 | -def _is_localhost(host, port): |
317 | - """ |
318 | - Checks whether a host is an alias to localhost. |
319 | - |
320 | - Raises socket.gaierror if the host was not found. |
321 | - |
322 | - """ |
323 | - |
324 | - addr = socket.getaddrinfo(host, port)[0][4][0] |
325 | - |
326 | - return addr in ('127.0.0.1', '::1') |
327 | + """Given a URL as http://host:port, returns (host, port).""" |
328 | + parsed = parse.urlparse(url) |
329 | + return (parsed.hostname, parsed.port) |
330 | |
331 | |
332 | def main(): |
333 | args = parse_arguments() |
334 | - |
335 | - if args.use_ssh_tunnel.lower() == 'true': |
336 | - localport = _get_unused_port() |
337 | - |
338 | - remotehost, remoteport = _get_host_and_port(args.remote_neo4j) |
339 | - |
340 | - try: |
341 | - is_loc = _is_localhost(remotehost, remoteport) |
342 | - except socket.gaierror: |
343 | - logging.error('Server {} not found.'.format(remotehost)) |
344 | - return |
345 | - |
346 | - if is_loc: |
347 | - # we are attempting to connect to localhost anyway, so we won't |
348 | - # bother to SSH to it. |
349 | - # There may be some ways the user can trick us into trying to SSH |
350 | - # to localhost anyway, but this will do as a first pass. |
351 | - # SSHing to localhost is undesirable because on my test computer, |
352 | - # we get 'connection refused' if we try. |
353 | - args.func(args) |
354 | - |
355 | - else: # we need to SSH |
356 | - try: |
357 | - _start_tunnel(localport, remotehost, remoteport, |
358 | - ssh_user=args.ssh_user) |
359 | - except OSError as e: |
360 | - logging.error(str(e)) |
361 | - return |
362 | - except KeyboardInterrupt: |
363 | - logging.info('Halting because of user interrupt.') |
364 | - return |
365 | - |
366 | - try: |
367 | - args.display_neo4j = args.remote_neo4j |
368 | - args.remote_neo4j = 'http://localhost:{}'.format(localport) |
369 | - args.func(args) |
370 | - except KeyboardInterrupt: |
371 | - # this probably happened because we were running Sextant Web |
372 | - # and Ctrl-C'ed out of it |
373 | - logging.info('Keyboard interrupt detected. Halting.') |
374 | - pass |
375 | - |
376 | - finally: |
377 | - _stop_tunnel(localport, remotehost) |
378 | - |
379 | - else: # no need to set up the ssh, just run sextant |
380 | - args.func(args) |
381 | - |
382 | - |
383 | + remotehost, remoteport = _get_host_and_port(args.remote_neo4j) |
384 | + no_ssh_tunnel = args.no_ssh_tunnel |
385 | + connection = None |
386 | + |
387 | + try: |
388 | + conn_args = (remotehost, remoteport, no_ssh_tunnel) |
389 | + with db_api.SextantConnection(*conn_args) as connection: |
390 | + args.func(connection, args) |
391 | + except sshmanager.SSHConnectionError as e: |
392 | + print(e.message) |
393 | + |
394 | + |
395 | if __name__ == '__main__': |
396 | main() |
397 | |
398 | |
399 | === added file 'src/sextant/csvwriter.py' |
400 | --- src/sextant/csvwriter.py 1970-01-01 00:00:00 +0000 |
401 | +++ src/sextant/csvwriter.py 2014-10-23 12:33:12 +0000 |
402 | @@ -0,0 +1,152 @@ |
403 | +import logging |
404 | + |
405 | +""" |
406 | +Provide a class for writing to row-limited csv files. |
407 | +""" |
408 | +__all__ = ('CSVWriter',) |
409 | + |
410 | + |
411 | +class CSVWriter(object): |
412 | + """ |
413 | + Write to csv files, automatically opening new ones at row maximum. |
414 | + |
415 | + Provides a write(*args) method which will add a row to the currently open |
416 | + csv file (internally managed) if there is room in it, otherwise close it, |
417 | + silently open a new one and write to that. |
418 | + |
419 | + Attributes: |
420 | + base_path: |
421 | + The base path of the output files - which will have a full path |
422 | + of form "<base_path><number>.csv" |
423 | + headers: |
424 | + A list or tuple of strings which will be used as the column |
425 | + headers. Attempts to write a row of data will induce a check |
426 | + that the length of the data provided is exactly that of this |
427 | + argument. |
428 | + max_rows: |
429 | + The maximum number of rows to write in each file (including the |
430 | + header row) before opening a new file. |
431 | + |
432 | + _fmt: |
433 | + The format string which will be used to write a row to the csv |
434 | + file. Of form '{},{},...,{}\n'. |
435 | + _file: |
436 | + The currently open file. |
437 | + _file_count: |
438 | + The number of files that the CSVWriter has written to. The next |
439 | + file to be opened will have name '<base_path><_file_count>.csv' |
440 | + _row_count: |
441 | + The number of rows (including the header row) in the current file. |
442 | + _total_row_count: |
443 | + The number of rows (including the header rows) in ALL files. |
444 | + |
445 | + """ |
446 | + # Filename fmt of output files - used with .format(base_path, number). |
447 | + _file_fmt = '{}{}.csv' |
448 | + |
449 | + def __init__(self, base_path, headers, max_rows): |
450 | + """ |
451 | + Initialise the writer for writing. |
452 | + |
453 | + Arguments: |
454 | + base_path: |
455 | + The base path of the output files - which will have a full path |
456 | + of form "<base_path><number>.csv" |
457 | + headers: |
458 | + A list or tuple of strings which will be used as the column |
459 | + headers. Attempts to write a row of data will induce a check |
460 | + that the length of the data provided is exactly that of this |
461 | + argument. |
462 | + max_rows: |
463 | + The maximum number of rows to write in each file (including the |
464 | + header row) before opening a new file. |
465 | + """ |
466 | + self.base_path = base_path |
467 | + self.headers = headers |
468 | + self.max_rows = max_rows |
469 | + |
470 | + self._fmt = ','.join('{}' for h in headers) + '\n' |
471 | + |
472 | + # The number of the file we are on and the line in it. |
473 | + self._file = None |
474 | + self._file_count = 0 |
475 | + self._row_count = 0 |
476 | + |
477 | + self._total_row_count = 0 |
478 | + |
479 | + self._open_new_file() |
480 | + |
481 | + def _open_new_file(self): |
482 | + """ |
483 | + Open a new file for editing, writing the headers in the first row. |
484 | + """ |
485 | + self._close_file() |
486 | + |
487 | + path = CSVWriter._file_fmt.format(self.base_path, self._file_count) |
488 | + self._file = open(path, 'w+') |
489 | + self._file_count += 1 |
490 | + self.write(*self.headers) |
491 | + |
492 | + def _close_file(self): |
493 | + """ |
494 | + Close the current file. |
495 | + |
496 | + NOTE that this method should ALWAYS be called before attempting to read |
497 | + from the file as it ensures that all changes have been written to disk, |
498 | + not only buffered. |
499 | + """ |
500 | + if self._file and not self._file.closed: |
501 | + logging.debug('csvwriter wrote {} lines to {}' |
502 | + .format(self._row_count, self._file.name)) |
503 | + self._file.close() |
504 | + |
505 | + self._row_count = 0 |
506 | + |
507 | + def write(self, *args): |
508 | + """ |
509 | + Add a row the to current file, or to a new one if max_rows is reached. |
510 | + |
511 | + The check against max_rows is made BEFORE writing the line. |
512 | + |
513 | + Raises: |
514 | + ValueError: |
515 | + If the length of *args is not exactly the length of |
516 | + self.headers - i.e. on attempt to write too many/too few items. |
517 | + |
518 | + Arguments: |
519 | + *args: |
520 | + Strings, which will be written into the columns of the current |
521 | + open csv file. |
522 | + """ |
523 | + if not len(args) == len(self.headers): |
524 | + msg = 'Attempted to write {} entries to file {} with {} columns' |
525 | + raise ValueError(msg.format(len(args), self.base_path, |
526 | + len(self.headers))) |
527 | + |
528 | + if self._row_count == self.max_rows: |
529 | + self._close_file() |
530 | + self._open_new_file() |
531 | + |
532 | + self._file.write(self._fmt.format(*args)) |
533 | + self._row_count += 1 |
534 | + self._total_row_count += 1 |
535 | + |
536 | + def file_iter(self): |
537 | + """ |
538 | + Return an iterator over the names of the files the writer has |
539 | + written to. |
540 | + """ |
541 | + fmt = CSVWriter._file_fmt |
542 | + return (fmt.format(self.base_path, i) for i in range(self._file_count)) |
543 | + |
544 | + def finish(self): |
545 | + """ |
546 | + Flush and close the current file. If a subsequent call to self.write |
547 | + is made, a new file will be created to contain it. |
548 | + |
549 | + Return the number of files we have written to and the total number |
550 | + of lines we have written. |
551 | + """ |
552 | + self._close_file() |
553 | + return self._file_count, self._total_row_count |
554 | + |
555 | |
556 | === modified file 'src/sextant/db_api.py' |
557 | --- src/sextant/db_api.py 2014-09-03 14:10:07 +0000 |
558 | +++ src/sextant/db_api.py 2014-10-23 12:33:12 +0000 |
559 | @@ -5,208 +5,348 @@ |
560 | # ----------------------------------------- |
561 | # API to interact with a Neo4J server: upload, query and delete programs in a DB |
562 | |
563 | -__all__ = ("Validator", "AddToDatabase", "FunctionQueryResult", "Function", |
564 | +from __future__ import print_function |
565 | + |
566 | +__all__ = ("validate_query", "DBProgram", "FunctionQueryResult", "Function", |
567 | "SextantConnection") |
568 | |
569 | +from sys import stdout |
570 | + |
571 | import re # for validation of function/program names |
572 | import logging |
573 | from datetime import datetime |
574 | import os |
575 | import getpass |
576 | from collections import namedtuple |
577 | - |
578 | -from neo4jrestclient.client import GraphDatabase |
579 | -import neo4jrestclient.client as client |
580 | - |
581 | +import random |
582 | +import socket |
583 | + |
584 | +import itertools |
585 | +import subprocess |
586 | +from time import time |
587 | + |
588 | +import neo4jrestclient.client as neo4jrestclient |
589 | + |
590 | +from sshmanager import SSHManager, SSHConnectionError |
591 | +from csvwriter import CSVWriter |
592 | + |
593 | +# The directory on the local machine to which csv files will be written |
594 | +# prior to copy over to the remote server. |
595 | +TMP_DIR = '/tmp/sextant' |
596 | + |
597 | +# A function is deemed 'common' if it has more than this |
598 | +# many connections. |
599 | COMMON_CUTOFF = 10 |
600 | -# a function is deemed 'common' if it has more than this |
601 | -# many connections |
602 | - |
603 | - |
604 | -class Validator(): |
605 | - """ Sanitises/checks strings, to prevent Cypher injection attacks""" |
606 | - |
607 | - @staticmethod |
608 | - def validate(input_): |
609 | - """ |
610 | - Checks whether we can allow a string to be passed into a Cypher query. |
611 | - :param input_: the string we wish to validate |
612 | - :return: bool(the string is allowed) |
613 | - """ |
614 | - regex = re.compile(r'^[A-Za-z0-9\-:\.\$_@\*\(\)%\+,]+$') |
615 | - return bool(regex.match(input_)) |
616 | - |
617 | - @staticmethod |
618 | - def sanitise(input_): |
619 | - """ |
620 | - Strips harmful characters from the given string. |
621 | - :param input_: string to sanitise |
622 | - :return: the sanitised string |
623 | - """ |
624 | - return re.sub(r'[^\.\-_a-zA-Z0-9]+', '', input_) |
625 | - |
626 | - |
627 | -class AddToDatabase(): |
628 | - """Updates the database, adding functions/calls to a given program""" |
629 | - |
630 | - def __init__(self, program_name='', sextant_connection=None, |
631 | - uploader='', uploader_id='', date=None): |
632 | - """ |
633 | - Object which can be used to add functions and calls to a new program |
634 | - :param program_name: the name of the new program to be created |
635 | - (must already be validated against Validator) |
636 | - :param sextant_connection: the SextantConnection to use for connections |
637 | - :param uploader: string identifier of user who is uploading |
638 | - :param uploader_id: string Unix user-id of logged-in user |
639 | - :param date: string date of today |
640 | - """ |
641 | - # program_name must be alphanumeric, to avoid injection attacks easily |
642 | - if not Validator.validate(program_name): |
643 | - return |
644 | + |
645 | + |
646 | + |
647 | +def set_common_cutoff(common_def): |
648 | + """ |
649 | + Sets the number of incoming connections at which we deem a function 'common' |
650 | + Default is 10 (which is used if this method is never called). |
651 | + :param common_def: number of incoming connections |
652 | + """ |
653 | + global COMMON_CUTOFF |
654 | + COMMON_CUTOFF = common_def |
655 | + |
656 | + |
657 | +def validate_query(string): |
658 | + """ |
659 | + Checks whether we can allow a string to be passed into a Cypher query. |
660 | + :param string: the string we wish to validate |
661 | + :return: bool(the string is allowed) |
662 | + """ |
663 | + regex = re.compile(r'^[A-Za-z0-9\-:\.\$_@\*\(\)%\+,]+$') |
664 | + return bool(regex.match(string)) |
665 | + |
666 | + |
667 | +class DBProgram(object): |
668 | + """ |
669 | + Representation of a program in the database. |
670 | + |
671 | + Provides add_function and add_call methods which locally register functions |
672 | + and calls. The commit method uploads everything to the database. |
673 | + |
674 | + Attributes: |
675 | + uploader, uploader_id, program_name, date: |
676 | + As in __init__. |
677 | + |
678 | + _conn: |
679 | + The SextantConnection object managing the database connection. |
680 | + _ssh: |
681 | + The SSHManager object belonging to the SextantConnection. |
682 | + _db: |
683 | + The database object belonging to the SextantConnection. |
684 | + |
685 | + _tmp_dir: |
686 | + The user-specific location of the local temporary directory. |
687 | + |
688 | + func_writer: |
689 | + A CSVWriter object which manages the csv files containing the |
690 | + list of functions in the program. |
691 | + call_writer: |
692 | + A CSVWriter object which manages the csv files containing the |
693 | + list of function calls in the program. |
694 | + |
695 | + add_func_query: |
696 | + A string for the cypher query used to create functions from a csv |
697 | + file. |
698 | + add_call_query: |
699 | + A string for the cypher query used to create funciton calls from |
700 | + a csv file. |
701 | + add_program_query: |
702 | + A string for the cypher query used to create the program node. |
703 | + """ |
704 | + |
705 | + def __init__(self, connection, program_name, uploader, uploader_id, date): |
706 | + """ |
707 | + Initialise the database program. |
708 | + |
709 | + A local temporary folder is created at 'TMP_DIR-<user_name>'. |
710 | + When functions or calls are added via the add_function/call methods, |
711 | + they are registered in csv files which are stored in this directory. |
712 | + |
713 | + Committing the program copies these files to the neo4j server and |
714 | + cleans the local tmp folder. |
715 | + |
716 | + Raises: |
717 | + ValueError: |
718 | + If the program_name is not alphanumeric. |
719 | + CommandError: |
720 | + If the command to create the temporary directory failed. |
721 | + |
722 | + Arguments: |
723 | + connection: |
724 | + The SextantConnection object which manages the connection to |
725 | + the database. |
726 | + program_name: |
727 | + The name to register the program under in the database. Must be |
728 | + alphanumeric. |
729 | + uploader: |
730 | + The name of the user who uploaded the program. |
731 | + uploader_id: |
732 | + A numeric id of the user who uploaded the program. |
733 | + date: |
734 | + A string representing the upload date. |
735 | + """ |
736 | + # Ensure an alphanumeric program name. |
737 | + if not validate_query(program_name): |
738 | + raise ValueError('program name must be alphanumeric, got: {}' |
739 | + .format(program_name)); |
740 | + |
741 | + self.uploader = uploader |
742 | + self.uploader_id = uploader_id |
743 | |
744 | self.program_name = program_name |
745 | - self.parent_database_connection = sextant_connection |
746 | - self._functions = {} |
747 | - self._funcs_tx = None # transaction for uploading functions |
748 | - self._calls_tx = None # transaction for uploading relationships |
749 | - |
750 | - if self.parent_database_connection: |
751 | - # we'll locally use db for short |
752 | - db = self.parent_database_connection._db |
753 | - |
754 | - parent_function = db.nodes.create(name=program_name, |
755 | - type='program', |
756 | - uploader=uploader, |
757 | - uploader_id=uploader_id, |
758 | - date=date) |
759 | - self._parent_id = parent_function.id |
760 | - |
761 | - self._funcs_tx = db.transaction(using_globals=False, for_query=True) |
762 | - self._calls_tx = db.transaction(using_globals=False, for_query=True) |
763 | - |
764 | - self._connections = [] |
765 | - |
766 | - @staticmethod |
767 | - def _get_display_name(function_name): |
768 | - """ |
769 | - Gets the name we will display to the user for this function name. |
770 | - |
771 | - For instance, if function_name were __libc_start_main@plt, we would |
772 | - return ("__libc_start_main", "plt_stub"). The returned function type is |
773 | - currently one of "plt_stub", "function_pointer" or "normal". |
774 | - |
775 | - :param function_name: the name straight from objdump of a function |
776 | - :return: ("display name", "function type") |
777 | - |
778 | - """ |
779 | - |
780 | - if function_name[-4:] == "@plt": |
781 | - display_name = function_name[:-4] |
782 | - function_group = "plt_stub" |
783 | - elif function_name[:20] == "_._function_pointer_": |
784 | - display_name = function_name |
785 | - function_group = "function_pointer" |
786 | - else: |
787 | - display_name = function_name |
788 | - function_group = "normal" |
789 | - |
790 | - return display_name, function_group |
791 | - |
792 | - def add_function(self, function_name): |
793 | - """ |
794 | - Adds a function to the program, ready to be sent to the remote database. |
795 | - If the function name is already in use, this method effectively does |
796 | - nothing and returns True. |
797 | - |
798 | - :param function_name: a string which must be alphanumeric |
799 | - :return: True if the request succeeded, False otherwise |
800 | - """ |
801 | - if not Validator.validate(function_name): |
802 | - return False |
803 | - if self.class_contains_function(function_name): |
804 | - return True |
805 | - |
806 | - display_name, function_group = self._get_display_name(function_name) |
807 | - |
808 | - query = ('START n = node({}) ' |
809 | - 'CREATE (n)-[:subject]->(m:func {{type: "{}", name: "{}"}}) ' |
810 | - 'RETURN m.name, id(m)') |
811 | - query = query.format(self._parent_id, function_group, display_name) |
812 | - |
813 | - self._funcs_tx.append(query) |
814 | - |
815 | - self._functions[function_name] = function_name |
816 | - |
817 | - return True |
818 | - |
819 | - def class_contains_function(self, function_to_find): |
820 | - """ |
821 | - Checks whether we contain a function with a given name. |
822 | - :param function_to_find: string name of the function we wish to look up |
823 | - :return: bool(the function exists in this AddToDatabase) |
824 | - """ |
825 | - return function_to_find in self._functions |
826 | - |
827 | - def class_contains_call(self, function_calling, function_called): |
828 | - """ |
829 | - Checks whether we contain a call between the two named functions. |
830 | - :param function_calling: string name of the calling-function |
831 | - :param function_called: string name of the called function |
832 | - :return: bool(function_calling calls function_called in us) |
833 | - """ |
834 | - return (function_calling, function_called) in self._connections |
835 | - |
836 | - def add_function_call(self, fn_calling, fn_called): |
837 | - """ |
838 | - Adds a function call to the program, ready to be sent to the database. |
839 | - Effectively does nothing if there is already a function call between |
840 | - these two functions. |
841 | - Function names must be alphanumeric for easy security purposes; |
842 | - returns False if they fail validation. |
843 | - :param fn_calling: the name of the calling-function as a string. |
844 | - It should already exist in the AddToDatabase; if it does not, |
845 | - this method will create a stub for it. |
846 | - :param fn_called: name of the function called by fn_calling. |
847 | - If it does not exist, we create a stub representation for it. |
848 | - :return: True if successful, False otherwise |
849 | - """ |
850 | - if not all((Validator.validate(fn_calling), |
851 | - Validator.validate(fn_called))): |
852 | - return False |
853 | - |
854 | - if not self.class_contains_function(fn_called): |
855 | - self.add_function(fn_called) |
856 | - if not self.class_contains_function(fn_calling): |
857 | - self.add_function(fn_calling) |
858 | - |
859 | - if not self.class_contains_call(fn_calling, fn_called): |
860 | - self._connections.append((fn_calling, fn_called)) |
861 | - |
862 | - return True |
863 | + self.date = date |
864 | + |
865 | + self._conn = connection |
866 | + self._ssh = connection._ssh |
867 | + self._db = connection._db |
868 | + |
869 | + self._tmp_dir = '{}-{}'.format(TMP_DIR, getpass.getuser()) |
870 | + |
871 | + # Make the local tmp directory - csv files will be written into it. |
872 | + try: |
873 | + os.makedirs(self._tmp_dir) |
874 | + except OSError as e: |
875 | + if e.errno == os.errno.EEXIST: # Directory already exists. |
876 | + pass |
877 | + else: |
878 | + raise e |
879 | + |
880 | + |
881 | + tmp_path = os.path.join(self._tmp_dir, '{}_{{}}'.format(program_name)) |
882 | + |
883 | + self.func_writer = CSVWriter(tmp_path.format('funcs'), |
884 | + headers=['name', 'type'], |
885 | + max_rows=5000) |
886 | + self.call_writer = CSVWriter(tmp_path.format('calls'), |
887 | + headers=['caller', 'callee'], |
888 | + max_rows=5000) |
889 | + |
890 | + # Define the queries we use to upload the functions and calls. |
891 | + self.add_func_query = (' USING PERIODIC COMMIT 250' |
892 | + ' LOAD CSV WITH HEADERS FROM "file:{}" AS line' |
893 | + ' WITH line, toInt(line.id) as lineid' |
894 | + ' MATCH (n:program {{name: "{}"}})' |
895 | + ' CREATE (n)-[:subject]->(m:func {{name: line.name,' |
896 | + ' id: lineid, type: line.type}})') |
897 | + |
898 | + self.add_call_query = (' USING PERIODIC COMMIT 250' |
899 | + ' LOAD CSV WITH HEADERS FROM "file:{}" AS line' |
900 | + ' MATCH (p:program {{name: "{}"}})' |
901 | + ' MATCH (p)-[:subject]->(n:func {{name: line.caller}})' |
902 | + ' USING INDEX n:func(name)' |
903 | + ' MATCH (p)-[:subject]->(m:func {{name: line.callee}})' |
904 | + ' USING INDEX m:func(name)' |
905 | + ' CREATE (n)-[r:calls]->(m)') |
906 | + |
907 | + self.add_program_query = ('CREATE (p:program {{name: "{}", uploader: "{}", ' |
908 | + ' uploader_id: "{}", date: "{}",' |
909 | + ' function_count: {}, call_count: {}}})') |
910 | + |
911 | + |
912 | + def __enter__(self): |
913 | + """ |
914 | + Allow DBProgram to be used as a context manager. |
915 | + """ |
916 | + return self |
917 | + |
918 | + def __exit__(self, etype, evalue, etrace): |
919 | + """ |
920 | + Make sure that all files are properly closed. |
921 | + """ |
922 | + self.func_writer.finish() |
923 | + self.call_writer.finish() |
924 | + |
925 | + # Propagate the error if there is one. |
926 | + return False if etype is not None else True |
927 | + |
928 | + def add_function(self, name, typ='normal'): |
929 | + """ |
930 | + Add a function. |
931 | + |
932 | + Arguments: |
933 | + name: |
934 | + The name of the function. |
935 | + typ: |
936 | + The type of the function, may be any string, but standard types |
937 | + are: |
938 | + normal: we have the disassembly for this function |
939 | + stub: we have the name but not the disassembly - usually |
940 | + an imported library function. |
941 | + pointer: we know only that the function exists, not its |
942 | + name or details. |
943 | + """ |
944 | + self.func_writer.write(name, typ) |
945 | + |
946 | + def add_call(self, caller, callee): |
947 | + """ |
948 | + Add a function call. |
949 | + |
950 | + Arguments: |
951 | + caller: |
952 | + The name of the function making the call. |
953 | + callee: |
954 | + The name of the function called. |
955 | + """ |
956 | + self.call_writer.write(caller, callee) |
957 | + |
958 | + |
959 | + def _copy_local_to_remote_tmp_dir(self): |
960 | + """ |
961 | + Move local tmp files to the server ready for upload. |
962 | + |
963 | + Return a tuple of iterators, the first over the paths on the remote |
964 | + machine of the function files, and the second over the paths of the |
965 | + call files. |
966 | + """ |
967 | + print('Sending files to remote server...', end='') |
968 | + stdout.flush() |
969 | + remote_funcs = self._ssh.send_to_tmp_dir(self.func_writer.file_iter()) |
970 | + remote_calls = self._ssh.send_to_tmp_dir(self.call_writer.file_iter()) |
971 | + print('finished.') |
972 | + return remote_funcs, remote_calls |
973 | + |
974 | + def _clean_tmp_files(self, remote_paths): |
975 | + """ |
976 | + Delete temporary files on the local and remote machine. |
977 | + |
978 | + Arguments: |
979 | + remote_paths: |
980 | + A list of the paths of the remote files. |
981 | + """ |
982 | + print('Cleaning temporary files...', end='') |
983 | + file_paths = list(itertools.chain(self.func_writer.file_iter(), |
984 | + self.call_writer.file_iter())) |
985 | + |
986 | + for path in file_paths: |
987 | + os.remove(path) |
988 | + |
989 | + os.rmdir(self._tmp_dir) |
990 | + |
991 | + try: |
992 | + # If the parent sextant temp folder is empty, remove it. |
993 | + os.rmdir(TMP_DIR) |
994 | + except: |
995 | + # There is other stuff in TMP_DIR (i.e. from other users), so |
996 | + # leave it. |
997 | + pass |
998 | + |
999 | + self._ssh.remove_from_tmp_dir(remote_paths) |
1000 | + |
1001 | + print('done.') |
1002 | + |
1003 | + def _create_db_constraints(self): |
1004 | + """ |
1005 | + Create indexes in the database on program and function names. |
1006 | + |
1007 | + The program name index is a constraint, which will also guarantee the |
1008 | + uniqueness of program names. |
1009 | + """ |
1010 | + # Prepare a transaction object which we use to execute cypher queries. |
1011 | + tx = self._db.transaction(using_globals=False, for_query=True) |
1012 | + |
1013 | + tx.append('CREATE CONSTRAINT ON (p:program) ASSERT p.name IS UNIQUE') |
1014 | + tx.append('CREATE INDEX ON :func(name)') |
1015 | + |
1016 | + # Apply the transaction. |
1017 | + tx.commit() |
1018 | |
1019 | def commit(self): |
1020 | """ |
1021 | - Call this when you are finished with the object. |
1022 | - Changes are not synced to the remote database until this is called. |
1023 | + Insert the program into the database. |
1024 | + |
1025 | + Move the local temp files created by our func_writer and call_writer |
1026 | + to the database server's temp directory. From there use cypher queries |
1027 | + to upload them into the database, before cleaning them up. |
1028 | """ |
1029 | - functions = self._funcs_tx.commit() # send off the function names |
1030 | - |
1031 | - # now functions is a list of QuerySequence objects, which each have a |
1032 | - # .elements property which produces [['name', id]] |
1033 | - |
1034 | - id_funcs = dict([seq.elements[0] for seq in functions]) |
1035 | - logging.info('Functions uploaded. Uploading calls...') |
1036 | - |
1037 | - # so id_funcs is a dict with id_funcs['name'] == id |
1038 | - for call in self._connections: |
1039 | - query = ('MATCH n WHERE id(n) = {} ' |
1040 | - 'MATCH m WHERE id(m) = {} ' |
1041 | - 'CREATE (n)-[:calls]->(m)') |
1042 | - query = query.format(id_funcs[self._get_display_name(call[0])[0]], |
1043 | - id_funcs[self._get_display_name(call[1])[0]]) |
1044 | - self._calls_tx.append(query) |
1045 | - |
1046 | - self._calls_tx.commit() |
1047 | + # Ensure that the most recent files are flushed and closed. |
1048 | + func_file_count, func_line_count = self.func_writer.finish() |
1049 | + call_file_count, call_line_count = self.call_writer.finish() |
1050 | + |
1051 | + # Account for the header line at the top of each file. |
1052 | + func_count = func_line_count - func_file_count |
1053 | + call_count = call_line_count - call_file_count |
1054 | + |
1055 | + # Get the remote path names as iterators, then make lists of them |
1056 | + # so that we can iterate over them more than once. |
1057 | + remote_f_iter, remote_c_iter = self._copy_local_to_remote_tmp_dir() |
1058 | + remote_funcs, remote_calls = map(list, (remote_f_iter, remote_c_iter)) |
1059 | + |
1060 | + # Create the indexes and constraints in the database. |
1061 | + self._create_db_constraints() |
1062 | + |
1063 | + |
1064 | + try: |
1065 | + tx = self._db.transaction(using_globals=False, for_query=True) |
1066 | + |
1067 | + # Create the program node in the database. |
1068 | + tx.append(self.add_program_query.format(self.program_name, self.uploader, |
1069 | + self.uploader_id, self.date, |
1070 | + func_count, call_count)) |
1071 | + tx.commit() |
1072 | + |
1073 | + # Create the functions, then the calls between them. |
1074 | + for files, query, descr in zip((remote_funcs, remote_calls), |
1075 | + (self.add_func_query, self.add_call_query), |
1076 | + ('funcs', 'calls')): |
1077 | + start = time() |
1078 | + for i, path in enumerate(files): |
1079 | + completed = int(100*float(i+1)/len(files)) |
1080 | + |
1081 | + print('\rUploading {}: {}%'.format(descr, completed), end='') |
1082 | + stdout.flush() |
1083 | + |
1084 | + tx.append(query.format(path, self.program_name)) |
1085 | + tx.commit() |
1086 | + end = time() |
1087 | + print(' done.') |
1088 | + |
1089 | + finally: |
1090 | + # Cleanup temporary folders |
1091 | + self._clean_tmp_files(remote_funcs + remote_calls) |
1092 | |
1093 | |
1094 | class FunctionQueryResult: |
1095 | @@ -219,7 +359,7 @@ |
1096 | self._update_common_functions() |
1097 | |
1098 | def __eq__(self, other): |
1099 | - # we make a dictionary so that we can perform easy comparison |
1100 | + # We make a dictionary so that we can perform easy comparison. |
1101 | selfdict = {func.name: func for func in self.functions} |
1102 | otherdict = {func.name: func for func in other.functions} |
1103 | |
1104 | @@ -243,20 +383,20 @@ |
1105 | if rest_output is None or not rest_output.elements: |
1106 | return [] |
1107 | |
1108 | - # how we store this is: a dict |
1109 | + # How we store this is: a dict |
1110 | # with keys 'functionname' |
1111 | # and values [the function object we will use, |
1112 | # and a set of (function names this function calls), |
1113 | - # and numeric ID of this node in the Neo4J database] |
1114 | + # and numeric ID of this node in the Neo4J database]. |
1115 | |
1116 | result = {} |
1117 | |
1118 | - # initial pass for names of functions |
1119 | + # Initial pass for names of functions. |
1120 | |
1121 | - # if the following assertion failed, we've probably called db.query |
1122 | + # If the following assertion failed, we've probably called db.query |
1123 | # to get it to not return client.Node objects, which is wrong. |
1124 | # we attempt to handle this a bit later; this should never arise, but |
1125 | - # we can cope with it happening in some cases, like the test suite |
1126 | + # we can cope with it happening in some cases, like the test suite. |
1127 | |
1128 | if type(rest_output.elements) is not list: |
1129 | logging.warning('Not a list: {}'.format(type(rest_output.elements))) |
1130 | @@ -264,11 +404,12 @@ |
1131 | for node_list in rest_output.elements: |
1132 | assert(isinstance(node_list, list)) |
1133 | for node in node_list: |
1134 | - if isinstance(node, client.Node): |
1135 | + if isinstance(node, neo4jrestclient.Node): |
1136 | name = node.properties['name'] |
1137 | node_id = node.id |
1138 | node_type = node.properties['type'] |
1139 | - else: # this is the handling we mentioned earlier; |
1140 | + else: |
1141 | + # This is the handling we mentioned earlier; |
1142 | # we are a dictionary instead of a list, as for some |
1143 | # reason we've returned Raw rather than Node data. |
1144 | # We should never reach this code, but just in case. |
1145 | @@ -283,7 +424,7 @@ |
1146 | set(), |
1147 | node_id] |
1148 | |
1149 | - # end initialisation of names-dictionary |
1150 | + # End initialisation of names-dictionary. |
1151 | |
1152 | if self._parent_db_connection is not None: |
1153 | # This is the normal case, of extracting results from a server. |
1154 | @@ -301,7 +442,7 @@ |
1155 | logging.debug('exec') |
1156 | results = new_tx.execute() |
1157 | |
1158 | - # results is a list of query results, each of those being a list of |
1159 | + # Results is a list of query results, each of those being a list of |
1160 | # calls. |
1161 | |
1162 | for call_list in results: |
1163 | @@ -315,7 +456,7 @@ |
1164 | # recall: set union is denoted by | |
1165 | |
1166 | else: |
1167 | - # we don't have a parent database connection. |
1168 | + # We don't have a parent database connection. |
1169 | # This has probably arisen because we created this object from a |
1170 | # test suite, or something like that. |
1171 | for node in rest_output.elements: |
1172 | @@ -353,19 +494,10 @@ |
1173 | func_list = [func for func in self.functions if func.name == name] |
1174 | return None if len(func_list) == 0 else func_list[0] |
1175 | |
1176 | - |
1177 | -def set_common_cutoff(common_def): |
1178 | - """ |
1179 | - Sets the number of incoming connections at which we deem a function 'common' |
1180 | - Default is 10 (which is used if this method is never called). |
1181 | - :param common_def: number of incoming connections |
1182 | - """ |
1183 | - global COMMON_CUTOFF |
1184 | - COMMON_CUTOFF = common_def |
1185 | - |
1186 | - |
1187 | class Function(object): |
1188 | - """Represents a function which might appear in a FunctionQueryResult.""" |
1189 | + """ |
1190 | + Represents a function which might appear in a FunctionQueryResult. |
1191 | + """ |
1192 | |
1193 | def __eq__(self, other): |
1194 | funcs_i_call_list = {func.name for func in self.functions_i_call} |
1195 | @@ -393,11 +525,11 @@ |
1196 | self.name = function_name |
1197 | self.is_common = False |
1198 | self._number_calling_me = 0 |
1199 | - # care: _number_calling_me is not automatically updated, except by |
1200 | + # Care: _number_calling_me is not automatically updated, except by |
1201 | # any invocation of FunctionQueryResult._update_common_functions. |
1202 | |
1203 | |
1204 | -class SextantConnection: |
1205 | +class SextantConnection(object): |
1206 | """ |
1207 | RESTful connection to a remote database. |
1208 | It can be used to create/delete/query programs. |
1209 | @@ -406,56 +538,214 @@ |
1210 | ProgramWithMetadata = namedtuple('ProgramWithMetadata', |
1211 | ['uploader', 'uploader_id', |
1212 | 'program_name', 'date', |
1213 | - 'number_of_funcs']) |
1214 | - |
1215 | - def __init__(self, url): |
1216 | - self.url = url |
1217 | - self._db = GraphDatabase(url) |
1218 | - |
1219 | - def new_program(self, name_of_program): |
1220 | + 'number_of_funcs', 'number_of_calls']) |
1221 | + |
1222 | + @staticmethod |
1223 | + def _is_localhost(host, port): |
1224 | + """ |
1225 | + Checks whether a host is an alias to localhost. |
1226 | + |
1227 | + Raises socket.gaierror if the host was not found. |
1228 | + """ |
1229 | + addr = socket.getaddrinfo(host, port)[0][4][0] |
1230 | + return addr in ('127.0.0.1', '::1') |
1231 | + |
1232 | + @staticmethod |
1233 | + def _is_port_used(port): |
1234 | + """ |
1235 | + Checks with the OS to see whether a port is open. |
1236 | + |
1237 | + We test the port by binding to it on 127.0.0.1. It must be a positive integer. |
1238 | + We raise ValueError if it is not. |
1239 | + :param port: integer port to check for openness |
1240 | + :return: bool(port is in use) |
1241 | + """ |
1242 | + result = False |
1243 | + |
1244 | + # We follow: |
1245 | + # http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python |
1246 | + if not (isinstance(port, int) and port > 0): |
1247 | + raise ValueError('port {} must be a positive integer.'.format(port)) |
1248 | + |
1249 | + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
1250 | + try: |
1251 | + sock.bind(('127.0.0.1', port)) |
1252 | + except socket.error as e: |
1253 | + if e.errno == os.errno.EADDRINUSE: |
1254 | + result = True |
1255 | + else: |
1256 | + raise |
1257 | + |
1258 | + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
1259 | + |
1260 | + return result # True if the port is already in use, False otherwise |
1261 | + |
1262 | + @staticmethod |
1263 | + def _get_unused_port(): |
1264 | + """ |
1265 | + Returns a port number between 10000 and 50000 which is not currently open. |
1266 | + """ |
1267 | + |
1268 | + keep_going = True |
1269 | + while keep_going: |
1270 | + portnum = random.randint(10000, 50000) |
1271 | + keep_going = SextantConnection._is_port_used(portnum) |
1272 | + return portnum |
1273 | + |
1274 | + |
1275 | + def __enter__(self): |
1276 | + return self |
1277 | + |
1278 | + def __exit__(self, etype, evalue, etrace): |
1279 | + self.close() |
1280 | + return False if etype is not None else True |
1281 | + |
1282 | + |
1283 | + def __init__(self, remotehost, remoteport, no_ssh_tunnel=False): |
1284 | + """ |
1285 | + Initialise the database and ssh connections. |
1286 | + |
1287 | + Arguments: |
1288 | + remotehost: |
1289 | + The remote host name to connect to. |
1290 | + remoteport: |
1291 | + The port number to connect to on the remote host. |
1292 | + no_ssh_tunnel: |
1293 | + Disables the SSHManager if True. Prevents program upload. |
1294 | + """ |
1295 | + |
1296 | + self.remote_host = remotehost |
1297 | + self.remote_port = remoteport |
1298 | + |
1299 | + |
1300 | + self._no_ssh_tunnel = no_ssh_tunnel |
1301 | + self._ssh = None |
1302 | + self._db = None |
1303 | + |
1304 | + self.open() |
1305 | + |
1306 | + def open(self): |
1307 | + local_port = SextantConnection._get_unused_port() |
1308 | + is_localhost = SextantConnection._is_localhost(self.remote_host, self.remote_port) |
1309 | + |
1310 | + if self._no_ssh_tunnel and not is_localhost: |
1311 | + raise SSHConnectionError('Cannot connect to the remote database ' |
1312 | + 'without an ssh connection.') |
1313 | + else: |
1314 | + # Either we are making an ssh tunnel or we are contacting localhost. |
1315 | + self._ssh = SSHManager(local_port, |
1316 | + self.remote_host, |
1317 | + self.remote_port, |
1318 | + is_localhost=is_localhost) |
1319 | + |
1320 | + port = self.remote_port if is_localhost else local_port |
1321 | + url = 'http://localhost:{}'.format(port) |
1322 | + |
1323 | + self._db = neo4jrestclient.GraphDatabase(url) |
1324 | + |
1325 | + def close(self): |
1326 | + """ |
1327 | + Close the ssh connection to clean up its resources. |
1328 | + """ |
1329 | + if self._ssh: |
1330 | + self._ssh.close() |
1331 | + |
1332 | + def new_program(self, program_name): |
1333 | """ |
1334 | Request that the remote database create a new program with the given name. |
1335 | This procedure will create a new program remotely; you can manipulate |
1336 | - that program using the returned AddToDatabase object. |
1337 | + that program using the returned DBProgram object. |
1338 | The name can appear in the database already, but this is not recommended |
1339 | because then delete_program will not know which to delete. Check first |
1340 | using self.check_program_exists. |
1341 | - The name specified must pass Validator.validate()ion; this is a measure |
1342 | + The name specified must pass validate_query(); this is a measure |
1343 | to prevent Cypher injection attacks. |
1344 | - :param name_of_program: string program name |
1345 | - :return: AddToDatabase instance if successful |
1346 | + :param program_name: string program name |
1347 | + :return: DBProgram instance if successful |
1348 | """ |
1349 | |
1350 | - if not Validator.validate(name_of_program): |
1351 | - raise ValueError( |
1352 | - "{} is not a valid program name".format(name_of_program)) |
1353 | + if not validate_query(program_name): |
1354 | + raise ValueError("{} is not a valid program name" |
1355 | + .format(program_name)) |
1356 | |
1357 | uploader = getpass.getuser() |
1358 | uploader_id = os.getuid() |
1359 | - |
1360 | - return AddToDatabase(sextant_connection=self, |
1361 | - program_name=name_of_program, |
1362 | - uploader=uploader, uploader_id=uploader_id, |
1363 | - date=str(datetime.now())) |
1364 | - |
1365 | - def delete_program(self, name_of_program): |
1366 | + timestr = datetime.now().strftime('%Y-%m-%d %H:%M:%S') |
1367 | + |
1368 | + return DBProgram(self, program_name, uploader, |
1369 | + uploader_id, date=timestr) |
1370 | + |
1371 | + def delete_program(self, program_name): |
1372 | """ |
1373 | Request that the remote database delete a specified program. |
1374 | - :param name_of_program: a string which must be alphanumeric only |
1375 | + :param program_name: a string which must be alphanumeric only |
1376 | :return: bool(request succeeded) |
1377 | """ |
1378 | - if not Validator.validate(name_of_program): |
1379 | - return False |
1380 | - |
1381 | - q = """MATCH (n) WHERE n.name= "{}" AND n.type="program" |
1382 | - OPTIONAL MATCH (n)-[r]-(b) OPTIONAL MATCH (b)-[rel]-() |
1383 | - DELETE b,rel DELETE n, r""".format(name_of_program) |
1384 | - |
1385 | - self._db.query(q) |
1386 | + if not program_name in self.get_program_names(): |
1387 | + print('No program `{}` in the database'.format(program_name)) |
1388 | + return True |
1389 | + else: |
1390 | + print('Deleting `{}` from the database. ' |
1391 | + 'This may take some time for larger programs.' |
1392 | + .format(program_name)) |
1393 | + |
1394 | + start = time() |
1395 | + tx = self._db.transaction(using_globals=False, for_query=True) |
1396 | + |
1397 | + count_query = (' MATCH (p:program {{name: "{}"}})' |
1398 | + ' RETURN p.function_count, p.call_count' |
1399 | + .format(program_name)) |
1400 | + |
1401 | + tx.append(count_query) |
1402 | + func_count, call_count = tx.commit()[0].elements[0] |
1403 | + |
1404 | + del_call_query = ('OPTIONAL MATCH (p:program {{name: "{}"}})' |
1405 | + '-[:subject]->(f:func)-[c:calls]->()' |
1406 | + ' WITH c LIMIT 5000 DELETE c RETURN count(distinct(c))' |
1407 | + .format(program_name)) |
1408 | + |
1409 | + del_func_query = ('OPTIONAL MATCH (p:program {{name: "{}"}})' |
1410 | + '-[s:subject]->(f:func)' |
1411 | + ' WITH s, f LIMIT 5000 DELETE s, f RETURN count(f)' |
1412 | + .format(program_name)) |
1413 | + |
1414 | + del_prog_query = ('MATCH (p:program {{name: "{}"}}) DELETE p' |
1415 | + .format(program_name)) |
1416 | + |
1417 | + # Delete calls first, a node may not be deleted until all relationships |
1418 | + # referencing it are deleted. |
1419 | + for count, query, descr in zip((call_count, func_count), |
1420 | + (del_call_query, del_func_query), |
1421 | + ('calls', 'funcs')): |
1422 | + # Change tracks whether the last delete did anything. We would |
1423 | + # like to use: while done < count: ..., but if the program has |
1424 | + # already been partially deleted then this will never terminate. |
1425 | + # Furthermore, if there are no functions or no calls, the while |
1426 | + # loop will be appropriately skipped. |
1427 | + change = count |
1428 | + done = 0 |
1429 | + while change: |
1430 | + completed = int(100 * float(done)/count) |
1431 | + print('\rDeleting {}: {}%'.format(descr, completed), end='') |
1432 | + stdout.flush() |
1433 | + |
1434 | + tx.append(query) |
1435 | + change = tx.commit()[0].elements[0][0] |
1436 | + done += change |
1437 | + if done: |
1438 | + print(' done.') |
1439 | + |
1440 | + # Delete the program node. |
1441 | + tx.append(del_prog_query) |
1442 | + tx.commit() |
1443 | + |
1444 | + end = time() |
1445 | + print('Finished in {:.2f}s.'.format(end - start)) |
1446 | |
1447 | return True |
1448 | |
1449 | - def _execute_query(self, prog_name='', query=''): |
1450 | + |
1451 | + def _execute_query(self, prog_name, query): |
1452 | """ |
1453 | Executes a Cypher query against the remote database. |
1454 | Note that this returns a FunctionQueryResult, so is unsuitable for any |
1455 | @@ -468,7 +758,7 @@ |
1456 | :param query: verbatim query we wish the server to execute |
1457 | :return: a FunctionQueryResult corresponding to the server's output |
1458 | """ |
1459 | - rest_output = self._db.query(query, returns=client.Node) |
1460 | + rest_output = self._db.query(query, returns=neo4jrestclient.Node) |
1461 | |
1462 | return FunctionQueryResult(parent_db=self._db, |
1463 | program_name=prog_name, |
1464 | @@ -481,12 +771,11 @@ |
1465 | method which requires a program-name input. |
1466 | :return: a list of function-name strings. |
1467 | """ |
1468 | - q = """MATCH (n) WHERE n.type = "program" RETURN n.name""" |
1469 | + q = 'MATCH (n:program) RETURN n.name' |
1470 | program_names = self._db.query(q, returns=str).elements |
1471 | |
1472 | - result = [el[0] for el in program_names] |
1473 | + return set(el[0] for el in program_names) |
1474 | |
1475 | - return set(result) |
1476 | |
1477 | def programs_with_metadata(self): |
1478 | """ |
1479 | @@ -498,27 +787,28 @@ |
1480 | |
1481 | """ |
1482 | |
1483 | - q = ("MATCH (base) WHERE base.type = 'program' " |
1484 | - "MATCH (base)-[:subject]->(n)" |
1485 | - "RETURN base.uploader, base.uploader_id, base.name, base.date, count(n)") |
1486 | + q = (' MATCH (p:program)' |
1487 | + ' RETURN p.uploader, p.uploader_id, p.name, p.date,' |
1488 | + ' p.function_count, p.call_count') |
1489 | result = self._db.query(q) |
1490 | return {self.ProgramWithMetadata(*res) for res in result} |
1491 | |
1492 | def check_program_exists(self, program_name): |
1493 | """ |
1494 | Execute query to check whether a program with the given name exists. |
1495 | - Returns False if the program_name fails validation against Validator. |
1496 | + Returns False if the program_name fails validation (i.e. is possibly |
1497 | + unsafe as a string in a cypher query). |
1498 | :return: bool(the program exists in the database). |
1499 | """ |
1500 | |
1501 | - if not Validator.validate(program_name): |
1502 | + if not validate_query(program_name): |
1503 | return False |
1504 | |
1505 | - q = ("MATCH (base) WHERE base.name = '{}' AND base.type = 'program' " |
1506 | - "RETURN count(base)").format(program_name) |
1507 | + q = ('MATCH (p:program {{name: "{}"}}) RETURN p LIMIT 1' |
1508 | + .format(program_name)) |
1509 | |
1510 | - result = self._db.query(q, returns=int) |
1511 | - return result.elements[0][0] > 0 |
1512 | + result = self._db.query(q, returns=neo4jrestclient.Node) |
1513 | + return bool(result) |
1514 | |
1515 | def check_function_exists(self, program_name, function_name): |
1516 | """ |
1517 | @@ -529,18 +819,18 @@ |
1518 | :param function_name: string name of the function to check for existence |
1519 | :return: bool(names validate correctly, and function exists in program) |
1520 | """ |
1521 | - if not self.check_program_exists(program_name): |
1522 | - return False |
1523 | - |
1524 | - if not Validator.validate(program_name): |
1525 | - return False |
1526 | - |
1527 | - q = ("MATCH (base) WHERE base.name = '{}' AND base.type = 'program'" |
1528 | - "MATCH (base)-[r:subject]->(m) WHERE m.name = '{}'" |
1529 | - "RETURN count(m)").format(program_name, function_name) |
1530 | - |
1531 | - result = self._db.query(q, returns=int) |
1532 | - return result.elements[0][0] > 0 |
1533 | + if not validate_query(program_name): |
1534 | + return False |
1535 | + |
1536 | + pmatch = '(:program {{name: "{}"}})'.format(program_name) |
1537 | + fmatch = '(f:func {{name: "{}"}})'.format(function_name) |
1538 | + # be explicit about index usage |
1539 | + q = (' MATCH {}-[:subject]->{} USING INDEX f:func(name)' |
1540 | + ' RETURN f LIMIT 1'.format(pmatch, fmatch)) |
1541 | + |
1542 | + # result will be an empty list if the function was not found |
1543 | + result = self._db.query(q, returns=neo4jrestclient.Node) |
1544 | + return bool(result) |
1545 | |
1546 | def get_function_names(self, program_name): |
1547 | """ |
1548 | @@ -552,12 +842,11 @@ |
1549 | a set of function-name strings otherwise. |
1550 | """ |
1551 | |
1552 | - if not self.check_program_exists(program_name): |
1553 | - return None |
1554 | + if not validate_query(program_name): |
1555 | + return set() |
1556 | |
1557 | - q = ("MATCH (base) WHERE base.name = '{}' AND base.type = 'program' " |
1558 | - "MATCH (base)-[r:subject]->(m) " |
1559 | - "RETURN m.name").format(program_name) |
1560 | + q = (' MATCH (:program {{name: "{}"}})-[:subject]->(f:func)' |
1561 | + ' RETURN f.name').format(program_name) |
1562 | return {func[0] for func in self._db.query(q)} |
1563 | |
1564 | def get_all_functions_called(self, program_name, function_calling): |
1565 | @@ -570,16 +859,13 @@ |
1566 | :return: FunctionQueryResult, maximal subgraph rooted at function_calling |
1567 | """ |
1568 | |
1569 | - if not self.check_program_exists(program_name): |
1570 | - return None |
1571 | - |
1572 | if not self.check_function_exists(program_name, function_calling): |
1573 | return None |
1574 | |
1575 | - q = """MATCH (base) WHERE base.name = '{}' ANd base.type = 'program' |
1576 | - MATCH (base)-[:subject]->(m) WHERE m.name='{}' |
1577 | - MATCH (m)-[:calls*]->(n) |
1578 | - RETURN distinct n, m""".format(program_name, function_calling) |
1579 | + q = (' MATCH (p:program {{name: "{}"}})-[:subject]->(f:func {{name: "{}"}})' |
1580 | + ' USING INDEX f:func(name)' |
1581 | + ' MATCH (f)-[:calls*]->(g) RETURN distinct f, g' |
1582 | + .format(program_name, function_calling)) |
1583 | |
1584 | return self._execute_query(program_name, q) |
1585 | |
1586 | @@ -593,16 +879,13 @@ |
1587 | :return: FunctionQueryResult, maximal connected subgraph with leaf function_called |
1588 | """ |
1589 | |
1590 | - if not self.check_program_exists(program_name): |
1591 | - return None |
1592 | - |
1593 | if not self.check_function_exists(program_name, function_called): |
1594 | return None |
1595 | |
1596 | - q = """MATCH (base) WHERE base.name = '{}' AND base.type = 'program' |
1597 | - MATCH (base)-[r:subject]->(m) WHERE m.name='{}' |
1598 | - MATCH (n)-[:calls*]->(m) WHERE n.name <> '{}' |
1599 | - RETURN distinct n , m""" |
1600 | + q = (' MATCH (p:program {{name: "{}"}})-[:subject]->(g:func {{name: "{}"}})' |
1601 | + ' USING INDEX g:func(name)' |
1602 | + ' MATCH (f)-[:calls*]->(g) WHERE f.name <> "{}"' |
1603 | + ' RETURN distinct f , g') |
1604 | q = q.format(program_name, function_called, program_name) |
1605 | |
1606 | return self._execute_query(program_name, q) |
1607 | @@ -628,12 +911,14 @@ |
1608 | if not self.check_function_exists(program_name, function_calling): |
1609 | return None |
1610 | |
1611 | - q = r"""MATCH (pr) WHERE pr.name = '{}' AND pr.type = 'program' |
1612 | - MATCH p=(start {{name: "{}" }})-[:calls*]->(end {{name:"{}"}}) |
1613 | - WHERE (pr)-[:subject]->(start) |
1614 | - WITH DISTINCT nodes(p) AS result |
1615 | - UNWIND result AS answer |
1616 | - RETURN answer""" |
1617 | + q = (' MATCH (p:program {{name: "{}"}})-[:subject]->(start:func {{name: "{}"}})' |
1618 | + ' USING INDEX start:func(name)' |
1619 | + ' MATCH (p)-[:subject]->(end:func {{name: "{}"}})' |
1620 | + ' USING INDEX end:func(name)' |
1621 | + ' MATCH path=(start)-[:calls*]->(end)' |
1622 | + ' WITH DISTINCT nodes(path) AS result' |
1623 | + ' UNWIND result AS answer' |
1624 | + ' RETURN answer') |
1625 | q = q.format(program_name, function_calling, function_called) |
1626 | |
1627 | return self._execute_query(program_name, q) |
1628 | @@ -648,11 +933,9 @@ |
1629 | if not self.check_program_exists(program_name): |
1630 | return None |
1631 | |
1632 | - query = """MATCH (base) WHERE base.name = '{}' AND base.type = 'program' |
1633 | - MATCH (base)-[subject:subject]->(m) |
1634 | - RETURN DISTINCT (m)""".format(program_name) |
1635 | - |
1636 | - return self._execute_query(program_name, query) |
1637 | + q = (' MATCH (p:program {{name: "{}"}})-[:subject]->(f:func)' |
1638 | + ' RETURN (f)'.format(program_name)) |
1639 | + return self._execute_query(program_name, q) |
1640 | |
1641 | def get_shortest_path_between_functions(self, program_name, func1, func2): |
1642 | """ |
1643 | @@ -671,9 +954,11 @@ |
1644 | if not self.check_function_exists(program_name, func2): |
1645 | return None |
1646 | |
1647 | - q = """MATCH (func1 {{ name:"{}" }}),(func2 {{ name:"{}" }}), |
1648 | - p = shortestPath((func1)-[:calls*]->(func2)) |
1649 | - UNWIND nodes(p) AS ans |
1650 | - RETURN ans""".format(func1, func2) |
1651 | + q = (' MATCH (p:program {{name: "{}"}})-[:subject]->(f:func {{name: "{}"}})' |
1652 | + ' USING INDEX f:func(name)' |
1653 | + ' MATCH (p)-[:subject]->(g:func {{name: "{}"}})' |
1654 | + ' MATCH path=shortestPath((f)-[:calls*]->(g))' |
1655 | + ' UNWIND nodes(path) AS ans' |
1656 | + ' RETURN ans'.format(program_name, func1, func2)) |
1657 | |
1658 | return self._execute_query(program_name, q) |
1659 | |
1660 | === modified file 'src/sextant/export.py' |
1661 | --- src/sextant/export.py 2014-09-04 09:46:18 +0000 |
1662 | +++ src/sextant/export.py 2014-10-23 12:33:12 +0000 |
1663 | @@ -46,7 +46,7 @@ |
1664 | font_name = "Helvetica" |
1665 | |
1666 | for func in program.get_functions(): |
1667 | - if func.type == "plt_stub": |
1668 | + if func.type == "stub": |
1669 | output_str += ' "{}" [fillcolor=pink, style=filled]\n'.format(func.name) |
1670 | elif func.type == "function_pointer": |
1671 | output_str += ' "{}" [fillcolor=yellow, style=filled]\n'.format(func.name) |
1672 | @@ -108,7 +108,7 @@ |
1673 | |
1674 | for func in program.get_functions(): |
1675 | display_func = ProgramConverter.get_display_name(func) |
1676 | - if func.type == "plt_stub": |
1677 | + if func.type == "stub": |
1678 | colour = "#ff00ff" |
1679 | elif func.type == "function_pointer": |
1680 | colour = "#99ffff" |
1681 | @@ -175,4 +175,4 @@ |
1682 | output_str += '<edge source="{}" target="{}"> <data key="calls">1</data> </edge>\n'.format(func.name, callee.name) |
1683 | |
1684 | output_str += '</graph>\n</graphml>' |
1685 | - return output_str |
1686 | \ No newline at end of file |
1687 | + return output_str |
1688 | |
1689 | === modified file 'src/sextant/objdump_parser.py' (properties changed: -x to +x) |
1690 | --- src/sextant/objdump_parser.py 2014-08-18 13:00:53 +0000 |
1691 | +++ src/sextant/objdump_parser.py 2014-10-23 12:33:12 +0000 |
1692 | @@ -1,273 +1,313 @@ |
1693 | -# ----------------------------------------- |
1694 | -# Sextant |
1695 | -# Copyright 2014, Ensoft Ltd. |
1696 | -# Author: Patrick Stevens |
1697 | -# ----------------------------------------- |
1698 | - |
1699 | -#!/usr/bin/python3 |
1700 | - |
1701 | -import re |
1702 | +#!/usr/bin/python |
1703 | import argparse |
1704 | -import os.path |
1705 | import subprocess |
1706 | import logging |
1707 | |
1708 | - |
1709 | -class ParsedObject(): |
1710 | - """ |
1711 | - Represents a function as parsed from an objdump disassembly. |
1712 | - Has a name (which is the verbatim name like '__libc_start_main@plt'), |
1713 | - a position (which is the virtual memory location in hex, like '08048320' |
1714 | - extracted from the dump), |
1715 | - and a canonical_position (which is the virtual memory location in hex |
1716 | - but stripped of leading 0s, so it should be a |
1717 | - unique id). |
1718 | - It also has a list what_do_i_call of ParsedObjects it calls using the |
1719 | - assembly keyword 'call'. |
1720 | - It has a list original_code of its assembler code, too, in case it's useful. |
1721 | - """ |
1722 | - |
1723 | - @staticmethod |
1724 | - def get_canonical_position(position): |
1725 | - return position.lstrip('0') |
1726 | - |
1727 | - def __eq__(self, other): |
1728 | - return self.name == other.name |
1729 | - |
1730 | - def __init__(self, input_lines=None, assembler_section='', function_name='', |
1731 | - ignore_function_pointers=True, function_pointer_id=None): |
1732 | - """ |
1733 | - Create a new ParsedObject given the definition-lines from objdump -S. |
1734 | - A sample first definition-line is '08048300 <__gmon_start__@plt>:\n' |
1735 | - but this method |
1736 | - expects to see the entire definition eg |
1737 | - |
1738 | -080482f0 <puts@plt>: |
1739 | - 80482f0: ff 25 00 a0 04 08 jmp *0x804a000 |
1740 | - 80482f6: 68 00 00 00 00 push $0x0 |
1741 | - 80482fb: e9 e0 ff ff ff jmp 80482e0 <_init+0x30> |
1742 | - |
1743 | - We also might expect assembler_section, which is for instance '.init' |
1744 | - in 'Disassembly of section .init:' |
1745 | - function_name is used if we want to give this function a custom name. |
1746 | - ignore_function_pointers=True will pretend that calls to (eg) *eax do |
1747 | - not exist; setting to False makes us create stubs for those calls. |
1748 | - function_pointer_id is only used internally; it refers to labelling |
1749 | - of function pointers if ignore_function_pointers is False. Each |
1750 | - stub is given a unique numeric ID: this parameter tells init where |
1751 | - to start counting these IDs from. |
1752 | - |
1753 | - """ |
1754 | - if input_lines is None: |
1755 | - # get around Python's inability to pass in empty lists by value |
1756 | - input_lines = [] |
1757 | - |
1758 | - self.name = function_name or re.search(r'<.+>', input_lines[0]).group(0).strip('<>') |
1759 | - self.what_do_i_call = [] |
1760 | - self.position = '' |
1761 | - |
1762 | - if input_lines: |
1763 | - self.position = re.search(r'^[0-9a-f]+', input_lines[0]).group(0) |
1764 | - self.canonical_position = ParsedObject.get_canonical_position(self.position) |
1765 | - self.assembler_section = assembler_section |
1766 | - self.original_code = input_lines[1:] |
1767 | +""" |
1768 | +Provide a parser class to extract functions and calls from an objdump file, |
1769 | +and a way to generate such a file from an object file. |
1770 | +""" |
1771 | +__all__ = ('Parser', 'run_objdump', 'FileNotFoundError') |
1772 | + |
1773 | + |
1774 | +class FileNotFoundError(Exception): |
1775 | + """ |
1776 | + Exception raised when Parser fails to open its file. |
1777 | + """ |
1778 | + pass |
1779 | + |
1780 | + |
1781 | +class Parser(object): |
1782 | + """ |
1783 | + Extract functions and calls from an object file or an objdump output file. |
1784 | + |
1785 | + Only the specified sections of the disassembled code will be parsed. |
1786 | + |
1787 | + Attributes: |
1788 | + path: |
1789 | + Set to file_path in __init__. |
1790 | + _file: |
1791 | + Set to file_object in __init__. |
1792 | + sections: |
1793 | + Initialised by taking the sections argument to __init__ and |
1794 | + converting it to a set. |
1795 | + ignore_ptrs: |
1796 | + Set to ignore_ptrs in __init__. |
1797 | + |
1798 | + section_count: |
1799 | + The number of sections that have been parsed. |
1800 | + function_count: |
1801 | + The number of functions that have been parsed. |
1802 | + call_count: |
1803 | + The number of function calls that have been parsed. |
1804 | + function_ptr_count: |
1805 | + The number of function pointers that have been detected. |
1806 | + _known_stubs: |
1807 | + A set of the names of functions with type 'stub' that have been |
1808 | + parsed - used to avoid registering a stub multiple times. |
1809 | + |
1810 | + """ |
1811 | + def __init__(self, file_path, file_object=None, |
1812 | + sections=None, ignore_ptrs=False, |
1813 | + add_function=None, add_call=None, |
1814 | + started=None, finished=None): |
1815 | + """ |
1816 | + Initialise the parser object. |
1817 | + |
1818 | + Raises: |
1819 | + FileNotFoundError: |
1820 | + If file_object was not provided and file_path couldn't be |
1821 | + opened. |
1822 | + |
1823 | + Arguments: |
1824 | + file_path: |
1825 | + The path of the objdump output file to parse, or the path of an |
1826 | + object file to run objdump on and then parse. |
1827 | + file_object: |
1828 | + None if file_path is the path to an object file. |
1829 | + OR the file object (providing 'for line in file_object') |
1830 | + sections: |
1831 | + A list of the names of the disassembly sections to parse. An empty |
1832 | + list will result in all sections being parsed. |
1833 | + ignore_ptrs: |
1834 | + If True, calls to function pointers will be ignored during parsing. |
1835 | + add_function: |
1836 | + A function to call when a function is parsed. Takes: |
1837 | + name: name of the parsed function |
1838 | + type: type of the parsed function |
1839 | + add_call: |
1840 | + A function to call when a function call is parsed. Takes: |
1841 | + caller: name of the calling function |
1842 | + callee: name of the called function |
1843 | + started: |
1844 | + A function to call when the parse begins. Takes: |
1845 | + parser: the Parser instance which has just begun parsing. |
1846 | + finished: |
1847 | + A function to call when the parse completes. Takes: |
1848 | + parser: the Parser instance which has just finished parsing. |
1849 | + e.g. if add_function/call have been set to write into files, |
1850 | + then finished may be set to properly flush and close them. |
1851 | + """ |
1852 | + self.path = file_path |
1853 | + try: |
1854 | + self._file = file_object or self._open_file(file_path) |
1855 | + except FileNotFoundError: |
1856 | + raise |
1857 | + |
1858 | + self.sections = set(sections or []) |
1859 | + self.ignore_ptrs = ignore_ptrs |
1860 | + |
1861 | + self.section_count = 0 |
1862 | + self.function_count = 0 |
1863 | + self.call_count = 0 |
1864 | + self.function_ptr_count = 0 |
1865 | + |
1866 | + # Avoid adding duplicate function stubs (as these are detected from |
1867 | + # function calls so may be repeated). |
1868 | + self._known_stubs = set() |
1869 | + |
1870 | + # By default print information to stdout. |
1871 | + def print_func(name, typ): |
1872 | + print('func {:25}{}'.format(name, typ)) |
1873 | + |
1874 | + def print_call(caller, callee): |
1875 | + print('call {:25}{:25}'.format(caller, callee)) |
1876 | + |
1877 | + def print_started(parser): |
1878 | + print('parse started: {}[{}]'.format(self.path, ', '.join(self.sections))) |
1879 | + |
1880 | + |
1881 | + def print_finished(parser): |
1882 | + print('parsed {} functions and {} calls'.format(self.function_count, self.call_count)) |
1883 | + |
1884 | + self.add_function = add_function or print_func |
1885 | + self.add_call = add_call or print_call |
1886 | + self.started = lambda: (started or print_started)(self) |
1887 | + self.finished = lambda: (finished or print_finished)(self) |
1888 | + |
1889 | + |
1890 | + def _get_function_ptr_name(self): |
1891 | + """ |
1892 | + Return a name for a new function pointer. |
1893 | + """ |
1894 | + name = 'func_ptr_{}'.format(self.function_ptr_count) |
1895 | + self.function_ptr_count += 1 |
1896 | + return name |
1897 | + |
1898 | + def _add_function_normal(self, name): |
1899 | + """ |
1900 | + Add a function which we have full assembly code for. |
1901 | + """ |
1902 | + self.add_function(name, 'normal') |
1903 | + self.function_count += 1 |
1904 | + |
1905 | + def _add_function_ptr(self, name): |
1906 | + """ |
1907 | + Add a function pointer. |
1908 | + """ |
1909 | + self.add_function(name, 'pointer') |
1910 | + self.function_count += 1 |
1911 | + |
1912 | + def _add_function_stub(self, name): |
1913 | + """ |
1914 | + Add a function stub - we have its name but none of its internals. |
1915 | + """ |
1916 | + if not name in self._known_stubs: |
1917 | + self._known_stubs.add(name) |
1918 | + self.add_function(name, 'stub') |
1919 | + self.function_count += 1 |
1920 | + |
1921 | + def _add_call(self, caller, callee): |
1922 | + """ |
1923 | + Add a function call from caller to callee. |
1924 | + """ |
1925 | + self.add_call(caller, callee) |
1926 | + self.call_count += 1 |
1927 | + |
1928 | + def parse(self): |
1929 | + """ |
1930 | + Parse self._file. |
1931 | + """ |
1932 | + self.started() |
1933 | + |
1934 | + if self._file is not None: |
1935 | + in_section = False # if we are in one of self.sections |
1936 | + current_function = None # track the caller for function calls |
1937 | + |
1938 | + for line in self._file: |
1939 | + if line.startswith('Disassembly'): |
1940 | + # 'Disassembly of section <name>:\n' |
1941 | + section = line.split(' ')[-1].rstrip(':\n') |
1942 | + in_section = section in self.sections if self.sections else True |
1943 | + if in_section: |
1944 | + self.section_count += 1 |
1945 | + |
1946 | + elif in_section: |
1947 | + if line.endswith('>:\n'): |
1948 | + # '<address> <<function_identifier>>:\n' |
1949 | + # with <function_identifier> of form: |
1950 | + # <function_name>[@plt] |
1951 | + function_identifier = line.split('<')[-1].split('>')[0] |
1952 | + |
1953 | + if '@' in function_identifier: |
1954 | + current_function = function_identifier.split('@')[0] |
1955 | + self._add_function_stub(current_function) |
1956 | + else: |
1957 | + current_function = function_identifier |
1958 | + self._add_function_normal(current_function) |
1959 | + |
1960 | + elif 'call ' in line or 'callq ' in line: |
1961 | + # WHITESPACE to prevent picking up function names |
1962 | + # containing 'call' |
1963 | + |
1964 | + # '<hex>: <hex> [l]call [hex] <callee_info>\n' |
1965 | + callee_info = line.split(' ')[-1].rstrip('\n') |
1966 | + |
1967 | + # Where <callee_info> is either |
1968 | + # 1) '*(<register>)' call to a fn pointer |
1969 | + # 2) '$<hex>,$<hex>' lcall to a fn pointer |
1970 | + # 3) '<<function_identifier>>' call to a named function |
1971 | + if '<' in callee_info and '>' in callee_info: |
1972 | + # call to a normal or stub function |
1973 | + # '<function_identifier>' is of form <name>[@/-/+]<...> |
1974 | + # from which we extract name |
1975 | + callee_is_ptr = False |
1976 | + function_identifier = callee_info.lstrip('<').rstrip('>\n') |
1977 | + if '@' in function_identifier: |
1978 | + callee = function_identifier.split('@')[0] |
1979 | + self._add_function_stub(callee) |
1980 | + else: |
1981 | + callee = function_identifier.split('-')[-1].split('+')[0] |
1982 | + # Do not add this fn now - it is a normal func |
1983 | + # so we know about it from elsewhere. |
1984 | + |
1985 | + else: |
1986 | + # Some kind of function pointer call. |
1987 | + callee_is_ptr = True |
1988 | + if not self.ignore_ptrs: |
1989 | + callee = self._get_function_ptr_name() |
1990 | + self._add_function_ptr(callee) |
1991 | + |
1992 | + # Add the call. |
1993 | + if not (self.ignore_ptrs and callee_is_ptr): |
1994 | + self._add_call(current_function, callee) |
1995 | |
1996 | - call_regex_compiled = (ignore_function_pointers and re.compile(r'\tcall. +[^\*]+\n')) or re.compile(r'\tcall. +.+\n') |
1997 | - |
1998 | - lines_where_i_call = [line for line in input_lines if call_regex_compiled.search(line)] |
1999 | - |
2000 | - if not ignore_function_pointers and not function_pointer_id: |
2001 | - function_pointer_id = [1] |
2002 | - |
2003 | - for line in lines_where_i_call: |
2004 | - # we'll catch call and callq for the moment |
2005 | - called = (call_regex_compiled.search(line).group(0))[8:].lstrip(' ').rstrip('\n') |
2006 | - if called[0] == '*' and ignore_function_pointers == False: |
2007 | - # we have a function pointer, which we'll want to give a distinct name |
2008 | - address = '0' |
2009 | - name = '_._function_pointer_' + str(function_pointer_id[0]) |
2010 | - function_pointer_id[0] += 1 |
2011 | - |
2012 | - self.what_do_i_call.append((address, name)) |
2013 | - |
2014 | - else: # we're not on a function pointer |
2015 | - called_split = called.split(' ') |
2016 | - if len(called_split) == 2: |
2017 | - address, name = called_split |
2018 | - name = name.strip('<>') |
2019 | - # we still want to remove address offsets like +0x09 from the end of name |
2020 | - match = re.match(r'^.+(?=\+0x[a-f0-9]+$)', name) |
2021 | - if match is not None: |
2022 | - name = match.group(0) |
2023 | - self.what_do_i_call.append((address, name.strip('<>'))) |
2024 | - else: # the format of the "what do i call" is not recognised as a name/address pair |
2025 | - self.what_do_i_call.append(tuple(called_split)) |
2026 | - |
2027 | - def __str__(self): |
2028 | - if self.position: |
2029 | - return 'Memory address ' + self.position + ' with name ' + self.name + ' in section ' + str( |
2030 | - self.assembler_section) |
2031 | + self.finished() |
2032 | + |
2033 | + self._file.close() |
2034 | + result = True |
2035 | else: |
2036 | - return 'Name ' + self.name |
2037 | - |
2038 | - def __repr__(self): |
2039 | - out_str = 'Disassembly of section ' + self.assembler_section + ':\n\n' + self.position + ' ' + self.name + ':\n' |
2040 | - return out_str + '\n'.join([' ' + line for line in self.original_code]) |
2041 | - |
2042 | - |
2043 | -class Parser: |
2044 | - # Class to manipulate the output of objdump |
2045 | - |
2046 | - def __init__(self, input_file_location='', file_contents=None, sections_to_view=None, ignore_function_pointers=False): |
2047 | - """Creates a new Parser, given an input file path. That path should be an output from objdump -D. |
2048 | - Alternatively, supply file_contents, as a list of each line of the objdump output. We expect newlines |
2049 | - to have been stripped from the end of each of these lines. |
2050 | - sections_to_view makes sure we only use the specified sections (use [] for 'all sections' and None for none). |
2051 | - """ |
2052 | - if file_contents is None: |
2053 | - file_contents = [] |
2054 | - |
2055 | - if sections_to_view is None: |
2056 | - sections_to_view = [] |
2057 | - |
2058 | - if input_file_location: |
2059 | - file_to_read = open(input_file_location, 'r') |
2060 | - self.source_string_list = [line for line in file_to_read] |
2061 | - file_to_read.close() |
2062 | - elif file_contents: |
2063 | - self.source_string_list = [string + '\n' for string in file_contents] |
2064 | - self.parsed_objects = [] |
2065 | - self.sections_to_view = sections_to_view |
2066 | - self.ignore_function_pointers = ignore_function_pointers |
2067 | - self.pointer_identifier = [1] |
2068 | - |
2069 | - def create_objects(self): |
2070 | - """ Go through the source_string_list, getting object names (like 'main') along with the corresponding |
2071 | - definitions, and put them into parsed_objects """ |
2072 | - if self.sections_to_view is None: |
2073 | - return |
2074 | - |
2075 | - is_in_section = lambda name: self.sections_to_view == [] or name in self.sections_to_view |
2076 | - |
2077 | - parsed_objects = [] |
2078 | - current_object = [] |
2079 | - current_section = '' |
2080 | - regex_compiled_addr_and_name = re.compile(r'[0-9a-f]+ <.+>:\n') |
2081 | - regex_compiled_section = re.compile(r'section .+:\n') |
2082 | - |
2083 | - for line in self.source_string_list[4:]: # we bodge, since the file starts with a little bit of guff |
2084 | - if regex_compiled_addr_and_name.match(line): |
2085 | - # we are a starting line |
2086 | - current_object = [line] |
2087 | - elif re.match(r'Disassembly of section', line): |
2088 | - current_section = regex_compiled_section.search(line).group(0).lstrip('section ').rstrip(':\n') |
2089 | - current_object = [] |
2090 | - elif line == '\n': |
2091 | - # we now need to stop parsing the current block, and store it |
2092 | - if len(current_object) > 0 and is_in_section(current_section): |
2093 | - parsed_objects.append(ParsedObject(input_lines=current_object, assembler_section=current_section, |
2094 | - ignore_function_pointers=self.ignore_function_pointers, |
2095 | - function_pointer_id=self.pointer_identifier)) |
2096 | - else: |
2097 | - current_object.append(line) |
2098 | - |
2099 | - # now we should be done. We assumed that blocks begin with r'[0-9a-f]+ <.+>:\n' and end with a newline. |
2100 | - # clear duplicates: |
2101 | - |
2102 | - self.parsed_objects = [] |
2103 | - for obj in parsed_objects: |
2104 | - if obj not in self.parsed_objects: # this is so that if we jump into the function at an offset, |
2105 | - # we still register it as being the old function, not some new function at a different address |
2106 | - # with the same name |
2107 | - self.parsed_objects.append(obj) |
2108 | - |
2109 | - # by this point, each object contains a self.what_do_i_call which is a list of tuples |
2110 | - # ('address', 'name') if the address and name were recognised, or else (thing1, thing2, ...) |
2111 | - # where the instruction was call thing1 thing2 thing3... . |
2112 | - |
2113 | - def object_lookup(self, object_name='', object_address=''): |
2114 | - """Returns the object with name object_name or address object_address (at least one must be given). |
2115 | - If objects with the given name or address |
2116 | - are not found, returns None.""" |
2117 | - |
2118 | - if object_name == '' and object_address == '': |
2119 | - return None |
2120 | - |
2121 | - trial_obj = self.parsed_objects |
2122 | - |
2123 | - if object_name != '': |
2124 | - trial_obj = [obj for obj in trial_obj if obj.name == object_name] |
2125 | - |
2126 | - if object_address != '': |
2127 | - trial_obj = [obj for obj in trial_obj if |
2128 | - obj.canonical_position == ParsedObject.get_canonical_position(object_address)] |
2129 | - |
2130 | - if len(trial_obj) == 0: |
2131 | - return None |
2132 | - |
2133 | - return trial_obj |
2134 | - |
2135 | -def get_parsed_objects(filepath, sections_to_view, not_object_file, readable=False, ignore_function_pointers=False): |
2136 | - if sections_to_view is None: |
2137 | - sections_to_view = [] # because we use None for "no sections"; the intent of not providing any sections |
2138 | - # on the command line was to look at all sections, not none |
2139 | - |
2140 | - # first, check whether the given file exists |
2141 | - if not os.path.isfile(filepath): |
2142 | - # we'd like to use FileNotFoundError, but we might be running under |
2143 | - # Python 2, which doesn't have it. |
2144 | - raise IOError(filepath + 'is not found.') |
2145 | - |
2146 | - #now the file should exist |
2147 | - if not not_object_file: #if it is something we need to run through objdump first |
2148 | - #we need first to run the object file through objdump |
2149 | - |
2150 | - objdump_file_contents = subprocess.check_output(['objdump', '-D', filepath]) |
2151 | - objdump_str = objdump_file_contents.decode('utf-8') |
2152 | - |
2153 | - p = Parser(file_contents=objdump_str.split('\n'), sections_to_view=sections_to_view, ignore_function_pointers=ignore_function_pointers) |
2154 | - else: |
2155 | + result = False |
2156 | + |
2157 | + return result |
2158 | + |
2159 | + def _open_file(self, path): |
2160 | + """ |
2161 | + Open and return the file at path. |
2162 | + |
2163 | + Raises: |
2164 | + FileNotFoundError: |
2165 | + If the file fails to open. |
2166 | + |
2167 | + Arguments: |
2168 | + path: |
2169 | + The path of the file to open. |
2170 | + """ |
2171 | try: |
2172 | - p = Parser(input_file_location=filepath, sections_to_view=sections_to_view, ignore_function_pointers=ignore_function_pointers) |
2173 | - except UnicodeDecodeError: |
2174 | - logging.error('File could not be parsed as a string. Did you mean to supply --object-file?') |
2175 | - return False |
2176 | - |
2177 | - if readable: # if we're being called from the command line |
2178 | - print('File read; beginning parse.') |
2179 | - #file is now read, and we start parsing |
2180 | - |
2181 | - p.create_objects() |
2182 | - return p.parsed_objects |
2183 | + result = open(path) |
2184 | + except Exception as e: |
2185 | + raise FileNotFoundError("parser failed to open `{}`: {}".format(path, e.strerror)) |
2186 | + |
2187 | + return result |
2188 | + |
2189 | + |
2190 | +def run_objdump(input_file): |
2191 | + """ |
2192 | + Run the objdump command on the file with the given path. |
2193 | + |
2194 | + Return the input file path and a file object representing the result of |
2195 | + the objdump. |
2196 | + |
2197 | + Arguments: |
2198 | + input_file: |
2199 | + The path of the file to run objdump on. |
2200 | + |
2201 | + """ |
2202 | + # A single section can be specified for parsing with the -j flag, |
2203 | + # but it is not obviously possible to parse multiple sections like this. |
2204 | + p = subprocess.Popen(['objdump', '-d', input_file, '--no-show-raw-insn'], |
2205 | + stdout=subprocess.PIPE) |
2206 | + g = subprocess.Popen(['egrep', 'Disassembly|call(q)? |>:$'], stdin=p.stdout, stdout=subprocess.PIPE) |
2207 | + return input_file, g.stdout |
2208 | + |
2209 | |
2210 | def main(): |
2211 | - argumentparser = argparse.ArgumentParser(description="Parse the output of objdump.") |
2212 | - argumentparser.add_argument('--filepath', metavar="FILEPATH", help="path to input file", type=str, nargs=1) |
2213 | - argumentparser.add_argument('--not-object-file', help="import text objdump output instead of the compiled file", default=False, |
2214 | - action='store_true') |
2215 | - argumentparser.add_argument('--sections-to-view', metavar="SECTIONS", |
2216 | - help="sections of disassembly to view, like '.text'; leave blank for 'all'", |
2217 | - type=str, nargs='*') |
2218 | - argumentparser.add_argument('--ignore-function-pointers', help='whether to skip parsing calls to function pointers', action='store_true', default=False) |
2219 | - |
2220 | - parsed = argumentparser.parse_args() |
2221 | + """ |
2222 | + Run the parser from the command line. |
2223 | + |
2224 | + The path of the target file, the sections to view and the ignore function |
2225 | + pointers flag are set with command line arguments. |
2226 | + """ |
2227 | + ap = argparse.ArgumentParser(description="Parse the output of objdump.") |
2228 | + ap.add_argument('--filepath', metavar="FILEPATH", |
2229 | + help="path to input file", type=str, nargs=1) |
2230 | + |
2231 | + ap.add_argument('--sections-to-view', metavar="SECTIONS", |
2232 | + help="disassembly sections to view, eg '.text'; leave blank for 'all'", |
2233 | + type=str, nargs='*') |
2234 | + ap.add_argument('--ignore-function-pointers', |
2235 | + help='skip parsing calls to function pointers', |
2236 | + action='store_true', default=False) |
2237 | + |
2238 | + args = ap.parse_args() |
2239 | |
2240 | - filepath = parsed.filepath[0] |
2241 | - sections_to_view = parsed.sections_to_view |
2242 | - not_object_file = parsed.not_object_file |
2243 | - readable = True |
2244 | - function_pointers = parsed.ignore_function_pointers |
2245 | - |
2246 | - parsed_objs = get_parsed_objects(filepath, sections_to_view, not_object_file, readable, function_pointers) |
2247 | - if parsed_objs is False: |
2248 | - return 1 |
2249 | - |
2250 | - if readable: |
2251 | - for named_function in parsed_objs: |
2252 | - print(named_function.name) |
2253 | - print([f[-1] for f in named_function.what_do_i_call]) # use [-1] to get the last element, since: |
2254 | - #either we are in ('address', 'name'), when we want the last element, or else we are in (thing1, thing2, ...) |
2255 | - #so for the sake of argument we'll take the last thing |
2256 | - |
2257 | -if __name__ == "__main__": |
2258 | + filepath = args.filepath[0] |
2259 | + sections = args.sections_to_view |
2260 | + ignore_ptrs = args.ignore_function_pointers |
2261 | + |
2262 | + parser = Parser(filepath, sections, ignore_ptrs) |
2263 | + parser.parse() |
2264 | + |
2265 | + |
2266 | +if __name__ == '__main__': |
2267 | main() |
2268 | |
2269 | === modified file 'src/sextant/query.py' |
2270 | --- src/sextant/query.py 2014-08-26 16:33:20 +0000 |
2271 | +++ src/sextant/query.py 2014-10-23 12:33:12 +0000 |
2272 | @@ -14,7 +14,7 @@ |
2273 | from .export import ProgramConverter |
2274 | |
2275 | |
2276 | -def query(remote_neo4j, input_query, display_neo4j='', program_name=None, |
2277 | +def query(connection, display_neo4j='', program_name=None, |
2278 | argument_1=None, argument_2=None, suppress_common=False): |
2279 | """ |
2280 | Run a query against the database at remote_neo4j. |
2281 | @@ -36,24 +36,24 @@ |
2282 | |
2283 | """ |
2284 | |
2285 | - if display_neo4j: |
2286 | - display_url = display_neo4j |
2287 | - else: |
2288 | - display_url = remote_neo4j |
2289 | + # if display_neo4j: |
2290 | + # display_url = display_neo4j |
2291 | + # else: |
2292 | + # display_url = remote_neo4j |
2293 | |
2294 | - try: |
2295 | - db = db_api.SextantConnection(remote_neo4j) |
2296 | - except requests.exceptions.ConnectionError as err: |
2297 | - logging.error("Could not connect to Neo4J server {}. Are you sure it is running?".format(display_url)) |
2298 | - logging.error(str(err)) |
2299 | - return 2 |
2300 | - #Not supported in python 2 |
2301 | - #except (urllib.exceptions.MaxRetryError): |
2302 | - # logging.error("Connection was refused to {}. Are you sure the server is running?".format(remote_neo4j)) |
2303 | - # return 2 |
2304 | - except Exception as err: |
2305 | - logging.exception(str(err)) |
2306 | - return 2 |
2307 | + # try: |
2308 | + # db = db_api.SextantConnection(remote_neo4j) |
2309 | + # except requests.exceptions.ConnectionError as err: |
2310 | + # logging.error("Could not connect to Neo4J server {}. Are you sure it is running?".format(display_url)) |
2311 | + # logging.error(str(err)) |
2312 | + # return 2 |
2313 | + # #Not supported in python 2 |
2314 | + # #except (urllib.exceptions.MaxRetryError): |
2315 | + # # logging.error("Connection was refused to {}. Are you sure the server is running?".format(remote_neo4j)) |
2316 | + # # return 2 |
2317 | + # except Exception as err: |
2318 | + # logging.exception(str(err)) |
2319 | + # return 2 |
2320 | |
2321 | prog = None |
2322 | names_list = None |
2323 | @@ -66,38 +66,38 @@ |
2324 | if argument_1 is None: |
2325 | print('Supply one function name to functions-calling.') |
2326 | return 1 |
2327 | - prog = db.get_all_functions_calling(program_name, argument_1) |
2328 | + prog = connection.get_all_functions_calling(program_name, argument_1) |
2329 | elif input_query == 'functions-called-by': |
2330 | if argument_1 is None: |
2331 | print('Supply one function name to functions-called-by.') |
2332 | return 1 |
2333 | - prog = db.get_all_functions_called(program_name, argument_1) |
2334 | + prog = connection.get_all_functions_called(program_name, argument_1) |
2335 | elif input_query == 'all-call-paths': |
2336 | if argument_1 is None and argument_2 is None: |
2337 | print('Supply two function names to calls-between.') |
2338 | return 1 |
2339 | - prog = db.get_call_paths(program_name, argument_1, argument_2) |
2340 | + prog = connection.get_call_paths(program_name, argument_1, argument_2) |
2341 | elif input_query == 'whole-program': |
2342 | - prog = db.get_whole_program(program_name) |
2343 | + prog = connection.get_whole_program(program_name) |
2344 | elif input_query == 'shortest-call-path': |
2345 | if argument_1 is None and argument_2 is None: |
2346 | print('Supply two function names to shortest-path.') |
2347 | return 1 |
2348 | - prog = db.get_shortest_path_between_functions(program_name, argument_1, argument_2) |
2349 | + prog = connection.get_shortest_path_between_functions(program_name, argument_1, argument_2) |
2350 | elif input_query == 'functions': |
2351 | if program_name is not None: |
2352 | - func_names = db.get_function_names(program_name) |
2353 | + func_names = connection.get_function_names(program_name) |
2354 | if func_names: |
2355 | names_list = list(func_names) |
2356 | else: |
2357 | print('No functions were found in program %s on server %s.' % (program_name, display_url)) |
2358 | else: |
2359 | - list_of_programs = db.get_program_names() |
2360 | + list_of_programs = connection.get_program_names() |
2361 | if not list_of_programs: |
2362 | print('Server %s database empty.' % (display_url)) |
2363 | return 0 |
2364 | |
2365 | - func_list = [db.get_function_names(prog_name) |
2366 | + func_list = [connection.get_function_names(prog_name) |
2367 | for prog_name in list_of_programs] |
2368 | |
2369 | if not func_list: |
2370 | @@ -105,7 +105,7 @@ |
2371 | else: |
2372 | names_list = func_list |
2373 | elif input_query == 'programs': |
2374 | - list_found = list(db.get_program_names()) |
2375 | + list_found = list(connection.get_program_names()) |
2376 | if not list_found: |
2377 | print('No programs were found on server {}.'.format(display_url)) |
2378 | else: |
2379 | @@ -122,7 +122,5 @@ |
2380 | print('Nothing was returned from the query.') |
2381 | |
2382 | |
2383 | -def audit(remote_neo4j): |
2384 | - db = db_api.SextantConnection(remote_neo4j) |
2385 | - |
2386 | - return db.programs_with_metadata() |
2387 | +def audit(connection): |
2388 | + return connection.programs_with_metadata() |
2389 | |
2390 | === added file 'src/sextant/sshmanager.py' |
2391 | --- src/sextant/sshmanager.py 1970-01-01 00:00:00 +0000 |
2392 | +++ src/sextant/sshmanager.py 2014-10-23 12:33:12 +0000 |
2393 | @@ -0,0 +1,278 @@ |
2394 | +import os |
2395 | +import getpass |
2396 | +import logging |
2397 | +import subprocess |
2398 | + |
2399 | +"""Provide a class to manage an SSH tunnel and controller""" |
2400 | +__all__ = ('SSHConnectionError', 'SSHCommandError', 'SSHManager') |
2401 | + |
2402 | +# The location of the temporary directory to create on the REMOTE machine. |
2403 | +# Temporary files will be scp'd here prior to upload to the neo4j database. |
2404 | +TMP_DIR = '/tmp/sextant' |
2405 | + |
2406 | + |
2407 | +class SSHConnectionError(Exception): |
2408 | + """ |
2409 | + An exception raised when an attempt to establish an ssh connection fails. |
2410 | + """ |
2411 | + pass |
2412 | + |
2413 | + |
2414 | +class SSHCommandError(Exception): |
2415 | + """ |
2416 | + An exception raised when an attempt to run a command over ssh fails. |
2417 | + """ |
2418 | + pass |
2419 | + |
2420 | + |
2421 | +class SSHManager(object): |
2422 | + """ |
2423 | + Manage an ssh tunnel with port forwarding. |
2424 | + |
2425 | + Attributes: |
2426 | + local_port: |
2427 | + The port number on the local machine to forward. |
2428 | + remote_host: |
2429 | + The host to ssh into. |
2430 | + remote_port: |
2431 | + The port number on the remote host to connect to. |
2432 | + ssh_user: |
2433 | + The username to use for sshing - defaults to None, in which case |
2434 | + the ssh connection uses the username of the user who ran sextant. |
2435 | + |
2436 | + _controller_name: |
2437 | + The base of the identifying name for the ssh controller - the |
2438 | + actual name will be a combination of this and the local port. |
2439 | + _is_localhost: |
2440 | + True if we are trying to ssh into localhost. In this case do not |
2441 | + open the tunnel, just provide the right api so the rest of Sextant |
2442 | + need not special case. |
2443 | + """ |
2444 | + |
2445 | + def __init__(self, local_port, remote_host, remote_port, |
2446 | + ssh_user=None, is_localhost=False): |
2447 | + """ |
2448 | + Open an SSH tunnel with multiplexing enabled. |
2449 | + |
2450 | + Raises: |
2451 | + ValueError: |
2452 | + If local_port or remote_port are not positive integers |
2453 | + |
2454 | + Arguments: |
2455 | + local_port: |
2456 | + The number of the local port to forward. |
2457 | + remote_host: |
2458 | + The name of the remote host to connect to. |
2459 | + remote_port: |
2460 | + The port number on the remote host to connect to. |
2461 | + ssh_user: |
2462 | + An alternative user name to use for the ssh login. |
2463 | + is_localhost: |
2464 | + True if we are trying to ssh into localhost. |
2465 | + """ |
2466 | + if not (isinstance(local_port, int) and local_port > 0): |
2467 | + raise ValueError( |
2468 | + 'Local port {} must be a positive integer.'.format(local_port)) |
2469 | + if not (isinstance(remote_port, int) and remote_port > 0): |
2470 | + raise ValueError( |
2471 | + 'Remote port {} must be a positive integer.'.format(remote_port)) |
2472 | + |
2473 | + self.local_port = local_port |
2474 | + self.remote_host = remote_host |
2475 | + self.remote_port = remote_port |
2476 | + self.ssh_user = ssh_user |
2477 | + |
2478 | + self._tmp_dir = '{}-{}'.format(TMP_DIR, self.ssh_user or getpass.getuser()) |
2479 | + |
2480 | + self._controller_name = 'sextantcontroller{}'.format(local_port) |
2481 | + self._is_localhost = is_localhost |
2482 | + |
2483 | + self._open() |
2484 | + |
2485 | + def _open(self): |
2486 | + """ |
2487 | + Helper function to open the SSH tunnel. |
2488 | + |
2489 | + Raises: |
2490 | + SSHConnectionError: |
2491 | + If the ssh command failed to run. |
2492 | + """ |
2493 | + if self._is_localhost: |
2494 | + return |
2495 | + |
2496 | + # Build the ssh command as an argument list; options are appended below. |
2497 | + cmd = ['ssh'] |
2498 | + |
2499 | + if self.ssh_user: |
2500 | + # ssh -l {user} ... sets the remote login username |
2501 | + cmd.extend(['-l', self.ssh_user]) |
2502 | + |
2503 | + # -L localport:localhost:remoteport forwards the port. |
2504 | + port_fwd = '{}:localhost:{}'.format(self.local_port, self.remote_port) |
2505 | + |
2506 | + # -M makes SSH able to accept slave connections. |
2507 | + # -S sets the location of a control socket (in this case, sextantcontroller |
2508 | + # with the local port number appended, just in case we run sextant twice |
2509 | + # simultaneously), so we know how to close the port again. |
2510 | + # -f goes into background; -N does not execute a remote command; |
2511 | + # -T says to remote host that we don't want a text shell. |
2512 | + cmd.extend(['-M', '-S', self._controller_name, '-fNT', |
2513 | + '-L', port_fwd, self.remote_host]) |
2514 | + |
2515 | + logging.debug('Opening SSH tunnel with cmd: {}'.format(' '.join(cmd))) |
2516 | + |
2517 | + rc = subprocess.call(cmd) |
2518 | + if rc: |
2519 | + raise SSHConnectionError('SSH setup failed with error {}'.format(rc)) |
2520 | + |
2521 | + logging.debug('SSH tunnel created') |
2522 | + |
2523 | + self._make_tmp_dir() |
2524 | + |
2525 | + def close(self): |
2526 | + """ |
2527 | + Close the SSH tunnel, then try to remove the remote temp directory. |
2528 | + """ |
2529 | + if self._is_localhost: |
2530 | + return |
2531 | + |
2532 | + # ssh -O sends a control command to the master process behind the socket given by -S; -q for quiet. |
2533 | + cmd = ['ssh', '-S', self._controller_name, |
2534 | + '-O', 'exit', '-q', self.remote_host] |
2535 | + |
2536 | + logging.debug('Shutting down SSH tunnel with cmd: `{}`' |
2537 | + .format(' '.join(cmd))) |
2538 | + |
2539 | + # SSH has a bug on some systems which causes it to ignore the -q flag |
2540 | + # meaning it prints "Exit request sent." to stderr. |
2541 | + # To avoid this, we grab stderr temporarily, and see if it's that string; |
2542 | + # if it is, suppress it. |
2543 | + pr = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
2544 | + stdout, stderr = pr.communicate() |
2545 | + if stderr.rstrip() != 'Exit request sent.': |
2546 | + logging.error('SSH shutdown stderr: {}'.format(stderr)) |
2547 | + |
2548 | + if pr.returncode == 0: |
2549 | + logging.debug('Shut down successfully') |
2550 | + else: |
2551 | + logging.error('SSH shutdown failed with code {}' |
2552 | + .format(pr.returncode)) |
2553 | + |
2554 | + # Clean the temporary directory we created earlier. |
2555 | + self._delete_tmp_dir() |
2556 | + |
2557 | + def _call(self, *args): |
2558 | + """ |
2559 | + Execute a command on the remote machine over SSH. |
2560 | + |
2561 | + Return a tuple of rc, stdout, stderr from the process call. |
2562 | + |
2563 | + Arguments: |
2564 | + *args: |
2565 | + Strings containing the individual words of the command to |
2566 | + execute. E.g. _call('ls', '-lh', '.'). |
2567 | + """ |
2568 | + if self._is_localhost: |
2569 | + return (1, None, 'Cannot call SSH command from localhost') |
2570 | + |
2571 | + ssh_cmd = ['ssh', '-S', self._controller_name, self.remote_host] |
2572 | + ssh_cmd.extend(args) |
2573 | + p = subprocess.Popen(ssh_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
2574 | + stdout, stderr = p.communicate() |
2575 | + |
2576 | + if p.returncode: |
2577 | + logging.debug('Call to `{}` failed with code: {}, stderr: {}' |
2578 | + .format(' '.join(ssh_cmd), p.returncode, stderr)) |
2579 | + |
2580 | + return p.returncode, stdout, stderr |
2581 | + |
2582 | + def _make_tmp_dir(self): |
2583 | + """ |
2584 | + Create the per-user temporary directory on the remote machine. |
2585 | + """ |
2586 | + self._call('mkdir', '-p', self._tmp_dir) |
2587 | + |
2588 | + def _delete_tmp_dir(self): |
2589 | + """ |
2590 | + Remove the temporary directory on the remote machine. |
2591 | + """ |
2592 | + self._call('rm', '-r', self._tmp_dir) |
2593 | + |
2594 | + |
2595 | + def send_to_tmp_dir(self, path_list): |
2596 | + """ |
2597 | + Send the specified files to the temporary directory on the remote machine. |
2598 | + |
2599 | + Return an iterator of save paths on the remote machine. |
2600 | + Raises: |
2601 | + ValueError: |
2602 | + If no file paths were provided, or if one or more of the |
2603 | + provided paths is not an actual file. |
2604 | + SSHCommandError: |
2605 | + If the scp command failed for any reason. |
2606 | + |
2607 | + Arguments: |
2608 | + path_list: |
2609 | + Iterator of paths to the files on the local machine. All files |
2610 | + will be checked before copying to ensure that they exist and |
2611 | + to prevent passing arbitrary arguments to the ssh _call |
2612 | + command. |
2613 | + """ |
2614 | + if not path_list: |
2615 | + raise ValueError('attempt to copy zero files') |
2616 | + |
2617 | + # If we are in localhost, we are not controlling the TMP_DIR, |
2618 | + # so the files are already there. |
2619 | + if self._is_localhost: |
2620 | + return path_list |
2621 | + |
2622 | + # Make sure we can take the len of path_list and iterate over it |
2623 | + # more than once. |
2624 | + path_list = list(path_list) |
2625 | + |
2626 | + # Check that actual files are being copied - not random strings. |
2627 | + to_copy = [f for f in path_list if os.path.isfile(f)] |
2628 | + |
2629 | + if len(to_copy) < len(path_list): |
2630 | + missed = [f for f in path_list if not f in to_copy] |
2631 | + raise ValueError('Attempted to copy non existant files: {}' |
2632 | + .format(', '.join(missed))) |
2633 | + |
2634 | + scp_cmd = ['scp'] |
2635 | + scp_cmd.extend(to_copy) |
2636 | + scp_cmd.append('{}:{}'.format(self.remote_host, self._tmp_dir)) |
2637 | + |
2638 | + proc = subprocess.Popen(scp_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) |
2639 | + rc = proc.wait() |
2640 | + if rc: |
2641 | + raise SSHCommandError('scp failed with code {}: {}'.format(rc, stderr)) |
2642 | + |
2643 | + return (os.path.join(self._tmp_dir, os.path.basename(f)) for f in to_copy) |
2644 | + |
2645 | + def remove_from_tmp_dir(self, path_list): |
2646 | + """ |
2647 | + Delete the files specified as arguments from the remote machine. |
2648 | + |
2649 | + The output of send_to_tmp_dir may be passed as input to this function. |
2650 | + |
2651 | + Raises: |
2652 | + SSHCommandError: |
2653 | + If the rm command fails for any reason. |
2654 | + |
2655 | + Arguments: |
2656 | + path_list: |
2657 | + Iterator of paths of the files on the remote machine, relative |
2658 | + to the temporary directory. E.g. remove_from_tmp_dir('foo') |
2659 | + will delete the file self._tmp_dir/foo |
2660 | + """ |
2661 | + if self._is_localhost: |
2662 | + return |
2663 | + |
2664 | + # Assume we can trust this file list. |
2665 | + paths = [os.path.join(self._tmp_dir, os.path.basename(f)) for f in path_list] |
2666 | + self._call('rm', *paths) |
2667 | + |
2668 | + |
2669 | + |
2670 | + |
2671 | + |
2672 | |
2673 | === added file 'src/sextant/test_all.sh' |
2674 | --- src/sextant/test_all.sh 1970-01-01 00:00:00 +0000 |
2675 | +++ src/sextant/test_all.sh 2014-10-23 12:33:12 +0000 |
2676 | @@ -0,0 +1,4 @@ |
2677 | +#!/usr/bin/bash |
2678 | + |
2679 | +PYTHONPATH=$PYTHONPATH:~/. |
2680 | +python -m unittest discover --pattern=test_*.py |
2681 | |
2682 | === added file 'src/sextant/test_csvwriter.py' |
2683 | --- src/sextant/test_csvwriter.py 1970-01-01 00:00:00 +0000 |
2684 | +++ src/sextant/test_csvwriter.py 2014-10-23 12:33:12 +0000 |
2685 | @@ -0,0 +1,89 @@ |
2686 | +#!/usr/bin/python |
2687 | +import unittest |
2688 | +from csvwriter import CSVWriter |
2689 | +import subprocess |
2690 | +from os import listdir |
2691 | + |
2692 | +class TestSequence(unittest.TestCase): |
2693 | + def get_writer(self, path='tmp_test', headers=['name', 'type'], split=100): |
2694 | + return CSVWriter(path, headers, split) |
2695 | + |
2696 | + def tearDown(self): |
2697 | + to_rm = [f for f in listdir('.') if f.startswith('tmp_test') and f.endswith('.csv')] |
2698 | + if to_rm: |
2699 | + rc = subprocess.call(['rm'] + to_rm) |
2700 | + if rc: |
2701 | + msg = 'failed to clean' |
2702 | + else: |
2703 | + msg = 'cleaned' |
2704 | + print('{} {} files {}'.format(msg, len(to_rm), to_rm)) |
2705 | + |
2706 | + def test_headers(self): |
2707 | + # check that headers are being written correctly |
2708 | + headers = ['some', 'headers', 'to', 'check'] |
2709 | + writer = self.get_writer(headers=headers) |
2710 | + writer.finish() |
2711 | + |
2712 | + expected_path = 'tmp_test0.csv' |
2713 | + self.assertEquals(writer.file_iter().next(), expected_path) |
2714 | + writer_file = open('tmp_test0.csv', 'r') |
2715 | + |
2716 | + self.assertEquals(writer_file.readline(), 'some,headers,to,check\n') |
2717 | + self.assertFalse(writer_file.readline()) # check that nothing extra is written |
2718 | + |
2719 | + writer_file.close() |
2720 | + |
2721 | + def test_writing(self): |
2722 | + # check that csv entries are written correctly, and errors |
2723 | + # appropriately raised for invalid input |
2724 | + writer = self.get_writer() |
2725 | + |
2726 | + self.assertRaises(ValueError, writer.write, 'too short') |
2727 | + self.assertRaises(ValueError, writer.write, 'slightly', 'too', 'long') |
2728 | + writer.write('just', 'write') |
2729 | + |
2730 | + writer.finish() |
2731 | + |
2732 | + writer_file = open(writer.file_iter().next(), 'r+') |
2733 | + |
2734 | + self.assertEqual(writer_file.readline(), 'name,type\n') |
2735 | + self.assertEqual(writer_file.readline(), 'just,write\n') |
2736 | + self.assertFalse(writer_file.readline()) |
2737 | + |
2738 | + writer_file.close() |
2739 | + |
2740 | + def test_split(self): |
2741 | + split = 10 |
2742 | + files = 10 |
2743 | + writer = self.get_writer(split=split) |
2744 | + |
2745 | + for i in xrange(files*(split-1)): # split-1 to account for header line |
2746 | + writer.write('an', 'entry') |
2747 | + |
2748 | + writer.finish() |
2749 | + |
2750 | + gen_count = sum(1 for f in writer.file_iter()) |
2751 | + self.assertEqual(gen_count, files, |
2752 | + 'generated {} files, expected {}' |
2753 | + .format(gen_count, files)) |
2754 | + |
2755 | + for f in writer.file_iter(): |
2756 | + with open(f, 'r+') as wf: |
2757 | + header_line = wf.readline() |
2758 | + header_expected = 'name,type\n' |
2759 | + self.assertEqual(header_line, header_expected, |
2760 | + '{} contained header {}, expected {}' |
2761 | + .format(f, header_line, header_expected)) # check headers |
2762 | + |
2763 | + # check line count |
2764 | + with open(f, 'r+') as wf: |
2765 | + line_count = sum(1 for line in wf) |
2766 | + self.assertEqual(line_count, split, |
2767 | + '{} contained {} lines, expected {}' |
2768 | + .format(f, line_count, split)) |
2769 | + |
2770 | + |
2771 | +if __name__ == '__main__': |
2772 | + unittest.main() |
2773 | + |
2774 | + |
2775 | |
2776 | === renamed file 'src/sextant/tests.py' => 'src/sextant/test_db_api.py' (properties changed: -x to +x) |
2777 | --- src/sextant/tests.py 2014-08-14 15:23:39 +0000 |
2778 | +++ src/sextant/test_db_api.py 2014-10-23 12:33:12 +0000 |
2779 | @@ -1,3 +1,4 @@ |
2780 | +#!/usr/bin/python |
2781 | # ----------------------------------------- |
2782 | # Sextant |
2783 | # Copyright 2014, Ensoft Ltd. |
2784 | @@ -10,56 +11,69 @@ |
2785 | from db_api import Function |
2786 | from db_api import FunctionQueryResult |
2787 | from db_api import SextantConnection |
2788 | -from db_api import Validator |
2789 | +from db_api import validate_query |
2790 | |
2791 | |
2792 | class TestFunctionQueryResults(unittest.TestCase): |
2793 | - def setUp(self): |
2794 | + @classmethod |
2795 | + def setUpClass(cls): |
2796 | # we need to set up the remote database by using the neo4j_input_api |
2797 | - self.remote_url = 'http://ensoft-sandbox:7474' |
2798 | - |
2799 | - self.setter_connection = SextantConnection(self.remote_url) |
2800 | - self.program_1_name = 'testprogram' |
2801 | - self.upload_program = self.setter_connection.new_program(self.program_1_name) |
2802 | - self.upload_program.add_function('func1') |
2803 | - self.upload_program.add_function('func2') |
2804 | - self.upload_program.add_function('func3') |
2805 | - self.upload_program.add_function('func4') |
2806 | - self.upload_program.add_function('func5') |
2807 | - self.upload_program.add_function('func6') |
2808 | - self.upload_program.add_function('func7') |
2809 | - self.upload_program.add_function_call('func1', 'func2') |
2810 | - self.upload_program.add_function_call('func1', 'func4') |
2811 | - self.upload_program.add_function_call('func2', 'func1') |
2812 | - self.upload_program.add_function_call('func2', 'func4') |
2813 | - self.upload_program.add_function_call('func3', 'func5') |
2814 | - self.upload_program.add_function_call('func4', 'func4') |
2815 | - self.upload_program.add_function_call('func4', 'func5') |
2816 | - self.upload_program.add_function_call('func5', 'func1') |
2817 | - self.upload_program.add_function_call('func5', 'func2') |
2818 | - self.upload_program.add_function_call('func5', 'func3') |
2819 | - self.upload_program.add_function_call('func6', 'func7') |
2820 | - |
2821 | - self.upload_program.commit() |
2822 | - |
2823 | - self.one_node_program_name = 'testprogram1' |
2824 | - self.upload_one_node_program = self.setter_connection.new_program(self.one_node_program_name) |
2825 | - self.upload_one_node_program.add_function('lonefunc') |
2826 | - |
2827 | - self.upload_one_node_program.commit() |
2828 | + cls.remote_url = 'http://ensoft-sandbox:7474' |
2829 | + |
2830 | + cls.setter_connection = SextantConnection('ensoft-sandbox', 7474) |
2831 | + |
2832 | + cls.program_1_name = 'testprogram' |
2833 | + cls.one_node_program_name = 'testprogram1' |
2834 | + cls.empty_program_name = 'testprogramblank' |
2835 | + |
2836 | + # if anything failed before, delete programs now |
2837 | + cls.setter_connection.delete_program(cls.program_1_name) |
2838 | + cls.setter_connection.delete_program(cls.one_node_program_name) |
2839 | + cls.setter_connection.delete_program(cls.empty_program_name) |
2840 | + |
2841 | + |
2842 | + cls.upload_program = cls.setter_connection.new_program(cls.program_1_name) |
2843 | + cls.upload_program.add_function('func1') |
2844 | + cls.upload_program.add_function('func2') |
2845 | + cls.upload_program.add_function('func3') |
2846 | + cls.upload_program.add_function('func4') |
2847 | + cls.upload_program.add_function('func5') |
2848 | + cls.upload_program.add_function('func6') |
2849 | + cls.upload_program.add_function('func7') |
2850 | + cls.upload_program.add_call('func1', 'func2') |
2851 | + cls.upload_program.add_call('func1', 'func4') |
2852 | + cls.upload_program.add_call('func2', 'func1') |
2853 | + cls.upload_program.add_call('func2', 'func4') |
2854 | + cls.upload_program.add_call('func3', 'func5') |
2855 | + cls.upload_program.add_call('func4', 'func4') |
2856 | + cls.upload_program.add_call('func4', 'func5') |
2857 | + cls.upload_program.add_call('func5', 'func1') |
2858 | + cls.upload_program.add_call('func5', 'func2') |
2859 | + cls.upload_program.add_call('func5', 'func3') |
2860 | + cls.upload_program.add_call('func6', 'func7') |
2861 | + |
2862 | + cls.upload_program.commit() |
2863 | + |
2864 | + cls.upload_one_node_program = cls.setter_connection.new_program(cls.one_node_program_name) |
2865 | + cls.upload_one_node_program.add_function('lonefunc') |
2866 | + |
2867 | + cls.upload_one_node_program.commit() |
2868 | |
2869 | - self.empty_program_name = 'testprogramblank' |
2870 | - self.upload_empty_program = self.setter_connection.new_program(self.empty_program_name) |
2871 | - |
2872 | - self.upload_empty_program.commit() |
2873 | - |
2874 | - self.getter_connection = SextantConnection(self.remote_url) |
2875 | - |
2876 | - def tearDown(self): |
2877 | - self.setter_connection.delete_program(self.upload_program.program_name) |
2878 | - self.setter_connection.delete_program(self.upload_one_node_program.program_name) |
2879 | - self.setter_connection.delete_program(self.upload_empty_program.program_name) |
2880 | - del(self.setter_connection) |
2881 | + cls.upload_empty_program = cls.setter_connection.new_program(cls.empty_program_name) |
2882 | + |
2883 | + cls.upload_empty_program.commit() |
2884 | + |
2885 | + cls.getter_connection = cls.setter_connection |
2886 | + |
2887 | + |
2888 | + @classmethod |
2889 | + def tearDownClass(cls): |
2890 | + cls.setter_connection.delete_program(cls.upload_program.program_name) |
2891 | + cls.setter_connection.delete_program(cls.upload_one_node_program.program_name) |
2892 | + cls.setter_connection.delete_program(cls.upload_empty_program.program_name) |
2893 | + |
2894 | + cls.setter_connection.close() |
2895 | + del(cls.setter_connection) |
2896 | |
2897 | def test_17_get_call_paths(self): |
2898 | reference1 = FunctionQueryResult(parent_db=None, program_name=self.program_1_name) |
2899 | @@ -134,7 +148,7 @@ |
2900 | |
2901 | def test_08_get_program_names(self): |
2902 | reference = {self.program_1_name, self.one_node_program_name, self.empty_program_name} |
2903 | - self.assertEqual(reference, self.getter_connection.get_program_names()) |
2904 | + self.assertTrue(reference.issubset(self.getter_connection.get_program_names())) |
2905 | |
2906 | |
2907 | def test_11_get_all_functions_called(self): |
2908 | @@ -249,13 +263,13 @@ |
2909 | self.assertIsNone(self.getter_connection.get_call_paths(self.one_node_program_name, 'notafunc', 'notafunc')) |
2910 | |
2911 | def test_10_validator(self): |
2912 | - self.assertFalse(Validator.validate('')) |
2913 | - self.assertTrue(Validator.validate('thisworks')) |
2914 | - self.assertTrue(Validator.validate('th1sw0rks')) |
2915 | - self.assertTrue(Validator.validate('12345')) |
2916 | - self.assertFalse(Validator.validate('this does not work')) |
2917 | - self.assertTrue(Validator.validate('this_does_work')) |
2918 | - self.assertFalse(Validator.validate("'")) # string consisting of a single quote mark |
2919 | + self.assertFalse(validate_query('')) |
2920 | + self.assertTrue(validate_query('thisworks')) |
2921 | + self.assertTrue(validate_query('th1sw0rks')) |
2922 | + self.assertTrue(validate_query('12345')) |
2923 | + self.assertFalse(validate_query('this does not work')) |
2924 | + self.assertTrue(validate_query('this_does_work')) |
2925 | + self.assertFalse(validate_query("'")) # string consisting of a single quote mark |
2926 | |
2927 | if __name__ == '__main__': |
2928 | - unittest.main() |
2929 | \ No newline at end of file |
2930 | + unittest.main() |
2931 | |
2932 | === added file 'src/sextant/test_parser.py' |
2933 | --- src/sextant/test_parser.py 1970-01-01 00:00:00 +0000 |
2934 | +++ src/sextant/test_parser.py 2014-10-23 12:33:12 +0000 |
2935 | @@ -0,0 +1,85 @@ |
2936 | +#!/usr/bin/python |
2937 | +from collections import defaultdict |
2938 | +import unittest |
2939 | +import subprocess |
2940 | + |
2941 | +import objdump_parser as parser |
2942 | + |
2943 | +DUMP_FILE = 'test_resources/parser_test.dump' |
2944 | + |
2945 | +class TestSequence(unittest.TestCase): |
2946 | + def setUp(self): |
2947 | + pass |
2948 | + |
2949 | + def add_function(self, dct, name, typ): |
2950 | + self.assertFalse(name in dct, "duplicate function added: {} into {}".format(name, dct.keys())) |
2951 | + dct[name] = typ |
2952 | + |
2953 | + def add_call(self, dct, caller, callee): |
2954 | + dct[caller].append(callee) |
2955 | + |
2956 | + def do_parse(self, path=DUMP_FILE, sections=['.text'], ignore_ptrs=False): |
2957 | + functions = {} |
2958 | + calls = defaultdict(list) |
2959 | + |
2960 | + # set the Parser to put output in local dictionaries |
2961 | + add_function = lambda n, t: self.add_function(functions, n, t) |
2962 | + add_call = lambda a, b: self.add_call(calls, a, b) |
2963 | + |
2964 | + p = parser.Parser(path, sections=sections, ignore_ptrs=ignore_ptrs, |
2965 | + add_function=add_function, add_call=add_call) |
2966 | + res = p.parse() |
2967 | + |
2968 | + parser.add_function = None |
2969 | + parser.add_call = None |
2970 | + |
2971 | + return res, functions, calls |
2972 | + |
2973 | + |
2974 | + def test_open(self): |
2975 | + self.assertRaises(parser.FileNotFoundError, parser.Parser, file_path='rubbish file') |
2976 | + |
2977 | + def test_functions(self): |
2978 | + # ensure that the correct functions are listed with the correct types |
2979 | + res, funcs, calls = self.do_parse() |
2980 | + |
2981 | + for name, typ in zip(['normal', 'duplicates', 'wierd$name', 'printf', 'func_ptr_3'], |
2982 | + ['normal', 'normal', 'normal', 'stub', 'pointer']): |
2983 | + self.assertTrue(name in funcs, "'{}' not found in function dictionary".format(name)) |
2984 | + self.assertEquals(funcs[name], typ) |
2985 | + |
2986 | + self.assertFalse('__gmon_start__' in funcs, "don't see a function defined in .plt") |
2987 | + |
2988 | + def test_no_ptrs(self): |
2989 | + # ensure that the ignore_ptrs flag is working |
2990 | + res, funcs, calls = self.do_parse(ignore_ptrs=True) |
2991 | + |
2992 | + self.assertFalse('pointer' in funcs.values()) |
2993 | + self.assertEqual(len(calls['normal']), 2) |
2994 | + |
2995 | + |
2996 | + def test_calls(self): |
2997 | + res, funcs, calls = self.do_parse() |
2998 | + |
2999 | + self.assertTrue('normal' in calls['main']) |
3000 | + self.assertTrue('duplicates' in calls['main']) |
3001 | + |
3002 | + normal_calls = sorted(['wierd$name', 'printf', 'func_ptr_3']) |
3003 | + self.assertEquals(sorted(calls['normal']), normal_calls) |
3004 | + |
3005 | + self.assertEquals(calls['duplicates'].count('normal'), 2) |
3006 | + self.assertEquals(calls['duplicates'].count('printf'), 2, |
3007 | + "expected 2 printf calls in {}".format(calls['duplicates'])) |
3008 | + self.assertTrue('func_ptr_4' in calls['duplicates']) |
3009 | + self.assertTrue('func_ptr_5' in calls['duplicates']) |
3010 | + |
3011 | + def test_sections(self): |
3012 | + res, funcs, calls = self.do_parse(sections=['.plt', '.text']) |
3013 | + |
3014 | + # check that we have got rid of the @s in the names |
3015 | + self.assertTrue('@' not in ''.join(funcs.keys()), "check names are extracted correctly") |
3016 | + self.assertTrue('__gmon_start__' in funcs, "see a function defined only in .plt") |
3017 | + |
3018 | + |
3019 | +if __name__ == '__main__': |
3020 | + unittest.main() |
3021 | |
3022 | === added directory 'src/sextant/test_resources' |
3023 | === added file 'src/sextant/test_resources/parser_test' |
3024 | Binary files src/sextant/test_resources/parser_test 1970-01-01 00:00:00 +0000 and src/sextant/test_resources/parser_test 2014-10-23 12:33:12 +0000 differ |
3025 | === added file 'src/sextant/test_resources/parser_test.c' |
3026 | --- src/sextant/test_resources/parser_test.c 1970-01-01 00:00:00 +0000 |
3027 | +++ src/sextant/test_resources/parser_test.c 2014-10-23 12:33:12 +0000 |
3028 | @@ -0,0 +1,57 @@ |
3029 | +// COMMENT |
3030 | +#include<stdio.h> |
3031 | + |
3032 | +static int |
3033 | +normal(int a); |
3034 | + |
3035 | +static int |
3036 | +wierd$name(int a); |
3037 | + |
3038 | +typedef int (*pointer)(int); |
3039 | + |
3040 | +static int |
3041 | +normal(int a) |
3042 | +{ |
3043 | + /* call a normal func, |
3044 | + * a stub and a pointer |
3045 | + */ |
3046 | + pointer ptr = wierd$name; |
3047 | + |
3048 | + wierd$name(a); |
3049 | + printf("%d\n", a); |
3050 | + ptr(a); |
3051 | + |
3052 | + return (a); |
3053 | +} |
3054 | + |
3055 | +static int |
3056 | +wierd$name(int a) |
3057 | +{ |
3058 | + return (a); |
3059 | +} |
3060 | + |
3061 | +static int |
3062 | +duplicates(int a) |
3063 | +{ |
3064 | + pointer ptr1 = wierd$name; |
3065 | + |
3066 | + /* check stubs don't get duplicated */ |
3067 | + printf("first %d\n", a); |
3068 | + printf("second %d\n", a); |
3069 | + |
3070 | + normal(a); |
3071 | + normal(a); |
3072 | + |
3073 | + ptr1(a); |
3074 | + ptr1(a); |
3075 | + |
3076 | + return (a); |
3077 | +} |
3078 | + |
3079 | +int |
3080 | +main(void) |
3081 | +{ |
3082 | + normal(1); |
3083 | + duplicates(1); |
3084 | + return (0); |
3085 | +} |
3086 | |
3087 | === added file 'src/sextant/test_resources/parser_test.dump' |
3088 | --- src/sextant/test_resources/parser_test.dump 1970-01-01 00:00:00 +0000 |
3089 | +++ src/sextant/test_resources/parser_test.dump 2014-10-23 12:33:12 +0000 |
3090 | @@ -0,0 +1,44 @@ |
3091 | +Disassembly of section .init: |
3092 | +080482b4 <_init>: |
3093 | + 80482b8: call 8048350 <__x86.get_pc_thunk.bx> |
3094 | + 80482cd: call 8048300 <__gmon_start__@plt> |
3095 | +Disassembly of section .plt: |
3096 | +080482e0 <printf@plt-0x10>: |
3097 | +080482f0 <printf@plt>: |
3098 | +08048300 <__gmon_start__@plt>: |
3099 | +08048310 <__libc_start_main@plt>: |
3100 | +Disassembly of section .text: |
3101 | +08048320 <_start>: |
3102 | + 804833c: call 8048310 <__libc_start_main@plt> |
3103 | +08048350 <__x86.get_pc_thunk.bx>: |
3104 | +08048360 <deregister_tm_clones>: |
3105 | + 8048386: call *%eax |
3106 | +08048390 <register_tm_clones>: |
3107 | + 80483c3: call *%edx |
3108 | +080483d0 <__do_global_dtors_aux>: |
3109 | + 80483df: call 8048360 <deregister_tm_clones> |
3110 | +080483f0 <frame_dummy>: |
3111 | + 804840f: call *%eax |
3112 | +0804841d <normal>: |
3113 | + 8048430: call 8048458 <wierd$name> |
3114 | + 8048443: call 80482f0 <printf@plt> |
3115 | + 8048451: call *%eax |
3116 | +08048458 <wierd$name>: |
3117 | +08048460 <duplicates>: |
3118 | + 804847b: call 80482f0 <printf@plt> |
3119 | + 804848e: call 80482f0 <printf@plt> |
3120 | + 8048499: call 804841d <normal> |
3121 | + 80484a4: call 804841d <normal> |
3122 | + 80484b2: call *%eax |
3123 | + 80484bd: call *%eax |
3124 | +080484c4 <main>: |
3125 | + 80484d4: call 804841d <normal> |
3126 | + 80484e0: call 8048460 <duplicates> |
3127 | +080484f0 <__libc_csu_init>: |
3128 | + 80484f6: call 8048350 <__x86.get_pc_thunk.bx> |
3129 | + 804850e: call 80482b4 <_init> |
3130 | + 804853b: call *-0xf8(%ebx,%edi,4) |
3131 | +08048560 <__libc_csu_fini>: |
3132 | +Disassembly of section .fini: |
3133 | +08048564 <_fini>: |
3134 | + 8048568: call 8048350 <__x86.get_pc_thunk.bx> |
3135 | |
3136 | === added file 'src/sextant/test_sshmanager.py' |
3137 | --- src/sextant/test_sshmanager.py 1970-01-01 00:00:00 +0000 |
3138 | +++ src/sextant/test_sshmanager.py 2014-10-23 12:33:12 +0000 |
3139 | @@ -0,0 +1,72 @@ |
3140 | +#!/usr/bin/python3 |
3141 | +import unittest |
3142 | +import sshmanager |
3143 | +import sshmanager |
3144 | +import os |
3145 | +sshmanager.TMP_DIR = '/home/benhutc/obj/csvload/src/sextant/test_resources/tmp' |
3146 | + |
3147 | + |
3148 | +class TestSequence(unittest.TestCase): |
3149 | + def setUp(self): |
3150 | + self.manager = None |
3151 | + |
3152 | + def tearDown(self): |
3153 | + if self.manager: |
3154 | + self.manager.close() |
3155 | + self.manager = None |
3156 | + |
3157 | + def get_manager(self, local_port=9643, remote_host='localhost', |
3158 | + remote_port=9643, ssh_user=None): |
3159 | + return sshmanager.SSHManager(local_port, remote_host, remote_port, ssh_user) |
3160 | + |
3161 | + def test_init(self): |
3162 | + self.assertRaises(ValueError, self.get_manager, local_port='invalid port') |
3163 | + self.assertRaises(ValueError, self.get_manager, remote_port='invalid port') |
3164 | + |
3165 | + def test_connect(self): |
3166 | + # make a connection to localhost and ensure that tmp is created |
3167 | + self.manager = self.get_manager() |
3168 | + self.assertTrue(os.path.isdir(self.manager._tmp_dir)) |
3169 | + self.manager.close() |
3170 | + self.assertFalse(os.path.isdir(self.manager._tmp_dir)) |
3171 | + self.manager = None |
3172 | + |
3173 | + # check connection failure |
3174 | + self.assertRaises(sshmanager.SSHConnectionError, self.get_manager, remote_host='invalid host') |
3175 | + |
3176 | + def test_files(self): |
3177 | + genuine_file = 'test_resources/parser_test.c' |
3178 | + genuine_file2 = 'test_resources/parser_test' |
3179 | + absent_file = 'absent_file' |
3180 | + |
3181 | + self.manager = self.get_manager() |
3182 | + # check sending no files fails |
3183 | + self.assertRaises(ValueError, self.manager.send_to_tmp_dir, []) |
3184 | + # and sending a non-existent file |
3185 | + self.assertRaises(ValueError, self.manager.send_to_tmp_dir, [absent_file, genuine_file]) |
3186 | + |
3187 | + self.manager.send_to_tmp_dir([genuine_file, genuine_file2]) |
3188 | + self.assertTrue(os.path.isfile(os.path.join(self.manager._tmp_dir, genuine_file.split('/')[-1]))) |
3189 | + self.assertTrue(os.path.isfile(os.path.join(self.manager._tmp_dir, genuine_file2.split('/')[-1]))) |
3190 | + |
3191 | + self.manager.remove_from_tmp_dir([genuine_file, genuine_file2]) |
3192 | + self.assertFalse(os.path.isfile(os.path.join(self.manager._tmp_dir, |
3193 | + genuine_file.split('/')[-1]))) |
3194 | + self.assertFalse(os.path.isfile(os.path.join(self.manager._tmp_dir, |
3195 | + genuine_file2.split('/')[-1]))) |
3196 | + |
3197 | + |
3198 | + self.manager.close() |
3199 | + self.manager = None |
3200 | + |
3201 | + |
3202 | +if __name__ == '__main__': |
3203 | + # no coverage for: |
3204 | + # specifying ssh user |
3205 | + # scp failure |
3206 | + # an error in closing the ssh connection |
3207 | + # another error in closing the ssh connection |
3208 | + # mkdir failure |
3209 | + # rmdir failure |
3210 | + unittest.main() |
3211 | + |
3212 | |
3213 | === modified file 'src/sextant/update_db.py' |
3214 | --- src/sextant/update_db.py 2014-09-29 14:01:39 +0000 |
3215 | +++ src/sextant/update_db.py 2014-10-23 12:33:12 +0000 |
3216 | @@ -5,72 +5,106 @@ |
3217 | # ----------------------------------------- |
3218 | # Given a program file to upload, or a program name to delete from the server, does the right thing. |
3219 | |
3220 | +from __future__ import print_function |
3221 | + |
3222 | __all__ = ("upload_program", "delete_program") |
3223 | |
3224 | -from .db_api import SextantConnection, Validator |
3225 | -from .objdump_parser import get_parsed_objects |
3226 | +from .db_api import SextantConnection |
3227 | +from .sshmanager import SSHConnectionError |
3228 | +from .objdump_parser import Parser, run_objdump |
3229 | from os import path |
3230 | +from time import time |
3231 | +import subprocess |
3232 | +import sys |
3233 | |
3234 | import logging |
3235 | |
3236 | - |
3237 | -def upload_program(user_name, file_path, db_url, display_url='', |
3238 | - alternative_name=None, not_object_file=False): |
3239 | - """ |
3240 | - Uploads a program to the remote database. |
3241 | - |
3242 | - Raises requests.exceptions.ConnectionError if the server didn't exist. |
3243 | - Raises IOError if file_path doesn't correspond to a file. |
3244 | - Raises ValueError if the desired alternative_name (or the default, if no |
3245 | - alternative_name was specified) already exists in the database. |
3246 | - :param file_path: the path to the local file we wish to upload |
3247 | - :param db_url: the URL of the database (eg. http://localhost:7474) |
3248 | - :param display_url: alternative URL to display instead of db_url |
3249 | - :param alternative_name: a name to give the program to override the default |
3250 | - :param object_file: bool(the file is an objdump text output file, rather than a compiled binary) |
3251 | - |
3252 | - """ |
3253 | - |
3254 | - if not display_url: |
3255 | - display_url = db_url |
3256 | - |
3257 | - # if no name is specified, use the form "<username>-<binary name>" |
3258 | - name = alternative_name or (user_name + '-' + path.split(file_path)[-1]) |
3259 | - |
3260 | - connection = SextantConnection(db_url) |
3261 | - |
3262 | - program_names = connection.get_program_names() |
3263 | - if Validator.sanitise(name) in program_names: |
3264 | - raise ValueError("There is already a program with name {}; " |
3265 | - "please delete the previous one with the same name " |
3266 | - "and retry, or rename the input file.".format(name)) |
3267 | - |
3268 | - parsed_objects = get_parsed_objects(filepath=file_path, |
3269 | - sections_to_view=['.text'], |
3270 | - not_object_file=not_object_file, |
3271 | - ignore_function_pointers=False) |
3272 | - |
3273 | - logging.info('Objdump has parsed!') |
3274 | - |
3275 | - program_representation = connection.new_program(Validator.sanitise(name)) |
3276 | - |
3277 | - for obj in parsed_objects: |
3278 | - for called in obj.what_do_i_call: |
3279 | - if not program_representation.add_function_call(obj.name, called[-1]): # called is a tuple (address, name) |
3280 | - logging.error('Validation error: {} calling {}'.format(obj.name, called[-1])) |
3281 | - |
3282 | - logging.info('Sending {} named objects to server {}...'.format(len(parsed_objects), display_url)) |
3283 | - program_representation.commit() |
3284 | - logging.info('Successfully added {}.'.format(name)) |
3285 | - |
3286 | - |
3287 | -def delete_program(program_name, db_url): |
3288 | - """ |
3289 | - Deletes a program with the specified name from the database. |
3290 | - :param program_name: the name of the program to delete |
3291 | - :param db_url: the URL of the database (eg. http://localhost:7474) |
3292 | - :return: bool(success) |
3293 | - """ |
3294 | - connection = SextantConnection(db_url) |
3295 | +def upload_program(connection, user_name, file_path, program_name=None, |
3296 | + not_object_file=False): |
3297 | + """ |
3298 | + Upload a program's functions and call graph to the database. |
3299 | + |
3300 | + Arguments: |
3301 | + connection: |
3302 | + The SextantConnection object that manages the database connection. |
3303 | + user_name: |
3304 | + The user name of the user uploading the program. |
3305 | + file_path: |
3306 | + The path to either: the output of objdump (if not_object_file is |
3307 | + True) OR to a binary file (if not_object_file is False). |
3308 | + program_name: |
3309 | + An optional name to give the program in the database, if not |
3310 | + specified then <user_name>-<file name> will be used. |
3311 | + not_object_file: |
3312 | + Flag controlling whether file_path is pointing to a dump file or |
3313 | + a binary file. |
3314 | + """ |
3315 | + if not connection._ssh: |
3316 | + raise SSHConnectionError('An SSH connection is required for ' |
3317 | + 'program upload.') |
3318 | + |
3319 | + if not program_name: |
3320 | + file_no_ext = path.basename(file_path).split('.')[0] |
3321 | + program_name = '{}-{}'.format(user_name, file_no_ext) |
3322 | + |
3323 | + |
3324 | + if program_name in connection.get_program_names(): |
3325 | + raise ValueError('A program with name `{}` already exists in the database' |
3326 | + .format(program_name)) |
3327 | + |
3328 | + |
3329 | + print('Uploading `{}` to the database. ' |
3330 | + 'This may take some time for larger programs.' |
3331 | + .format(program_name)) |
3332 | + start = time() |
3333 | + |
3334 | + if not not_object_file: |
3335 | + print('Generating dump file...', end='') |
3336 | + sys.stdout.flush() |
3337 | + file_path, file_object = run_objdump(file_path) |
3338 | + print('done.') |
3339 | + else: |
3340 | + file_object = None |
3341 | + |
3342 | + # Make parser and wire to DBprogram. |
3343 | + with connection.new_program(program_name) as program: |
3344 | + |
3345 | + def start_parser(program): |
3346 | + print('Parsing dump file...', end='') |
3347 | + sys.stdout.flush() |
3348 | + |
3349 | + def finish_parser(parser, program): |
3350 | + # Callback to make sure the program's csv files are flushed when |
3351 | + # the parser completes. |
3352 | + program.func_writer.finish() |
3353 | + program.call_writer.finish() |
3354 | + |
3355 | + print('done: {} functions and {} calls.' |
3356 | + .format(parser.function_count, parser.call_count)) |
3357 | + |
3358 | + parser = Parser(file_path = file_path, file_object = file_object, |
3359 | + sections=[], |
3360 | + add_function = program.add_function, |
3361 | + add_call = program.add_call, |
3362 | + started=lambda parser: start_parser(program), |
3363 | + finished=lambda parser: finish_parser(parser, program)) |
3364 | + parser.parse() |
3365 | + |
3366 | + program.commit() |
3367 | + |
3368 | + end = time() |
3369 | + print('Finished in {:.2f}s.'.format(end-start)) |
3370 | + |
3371 | + |
3372 | +def delete_program(connection, program_name): |
3373 | + """ |
3374 | + Remove the specified program from the database. |
3375 | + |
3376 | + Arguments: |
3377 | + connection: |
3378 | + The SextantConnection object managing the database connection. |
3379 | + program_name: |
3380 | + The name of the program to remove from the database. |
3381 | + """ |
3382 | connection.delete_program(program_name) |
3383 | - print('Successfully deleted {}.'.format(program_name)) |
3384 | + |
3385 | |
3386 | === modified file 'src/sextant/web/server.py' |
3387 | --- src/sextant/web/server.py 2014-10-03 11:47:52 +0000 |
3388 | +++ src/sextant/web/server.py 2014-10-23 12:33:12 +0000 |
3389 | @@ -26,7 +26,8 @@ |
3390 | |
3391 | from cgi import escape # deprecated in Python 3 in favour of html.escape, but we're stuck on Python 2 |
3392 | |
3393 | -database_url = None # the URL to access the database instance |
3394 | +# global SextantConnection object which deals with the port forwarding |
3395 | +CONNECTION = None |
3396 | |
3397 | RESPONSE_CODE_OK = 200 |
3398 | RESPONSE_CODE_BAD_REQUEST = 400 |
3399 | @@ -67,25 +68,6 @@ |
3400 | |
3401 | class SVGRenderer(Resource): |
3402 | |
3403 | - def error_creating_neo4j_connection(self, failure): |
3404 | - self.write("Error creating Neo4J connection: %s\n") % failure.getErrorMessage() |
3405 | - |
3406 | - @staticmethod |
3407 | - def create_neo4j_connection(): |
3408 | - return db_api.SextantConnection(database_url) |
3409 | - |
3410 | - @staticmethod |
3411 | - def check_program_exists(connection, name): |
3412 | - return connection.check_program_exists(name) |
3413 | - |
3414 | - @staticmethod |
3415 | - def get_whole_program(connection, name): |
3416 | - return connection.get_whole_program(name) |
3417 | - |
3418 | - @staticmethod |
3419 | - def get_functions_calling(connection, progname, funcname): |
3420 | - return connection.get_all_functions_calling(progname, funcname) |
3421 | - |
3422 | @staticmethod |
3423 | def get_plot(program, suppress_common_functions=False, remove_self_calls=False): |
3424 | graph_dot = export.ProgramConverter.to_dot(program, suppress_common_functions, |
3425 | @@ -111,7 +93,7 @@ |
3426 | res_msg = None # set this in the logic |
3427 | |
3428 | # |
3429 | - # Get program name and database connection, check if program exists |
3430 | + # Check if provided program name exists |
3431 | # |
3432 | |
3433 | name = args.get('program_name', [None])[0] |
3434 | @@ -121,16 +103,7 @@ |
3435 | res_msg = "Supply 'program_name' parameter." |
3436 | |
3437 | if res_code is RESPONSE_CODE_OK: |
3438 | - try: |
3439 | - conn = yield deferToThread(self.create_neo4j_connection) |
3440 | - except requests.exceptions.ConnectionError: |
3441 | - res_code = RESPONSE_CODE_BAD_GATEWAY |
3442 | - res_fmt = "Could not reach Neo4j server at {}" |
3443 | - res_msg = res_fmt.format(database_url) |
3444 | - conn = None |
3445 | - |
3446 | - if res_code is RESPONSE_CODE_OK: |
3447 | - exists = yield deferToThread(self.check_program_exists, conn, name) |
3448 | + exists = yield deferToThread(CONNECTION.check_program_exists, name) |
3449 | if not exists: |
3450 | res_code = RESPONSE_CODE_NOT_FOUND |
3451 | res_fmt = "Program {} not found in database." |
3452 | @@ -146,28 +119,23 @@ |
3453 | # look for in request.args, both tuples |
3454 | queries = { |
3455 | 'whole_program': ( |
3456 | - self.get_whole_program, |
3457 | - (conn, name), |
3458 | + CONNECTION.get_whole_program, |
3459 | () |
3460 | ), |
3461 | 'functions_calling': ( |
3462 | - self.get_functions_calling, |
3463 | - (conn, name), |
3464 | + CONNECTION.get_all_functions_calling, |
3465 | ('func1',) |
3466 | ), |
3467 | 'functions_called_by': ( |
3468 | - conn.get_all_functions_called, |
3469 | - (name,), |
3470 | + CONNECTION.get_all_functions_called, |
3471 | ('func1',) |
3472 | ), |
3473 | 'all_call_paths': ( |
3474 | - conn.get_call_paths, |
3475 | - (name,), |
3476 | + CONNECTION.get_call_paths, |
3477 | ('func1', 'func2') |
3478 | ), |
3479 | 'shortest_call_path': ( |
3480 | - conn.get_shortest_path_between_functions, |
3481 | - (name,), |
3482 | + CONNECTION.get_shortest_path_between_functions, |
3483 | ('func1', 'func2') |
3484 | ) |
3485 | } |
3486 | @@ -186,7 +154,7 @@ |
3487 | |
3488 | # extract any required keyword arguments from request.args |
3489 | if res_code is RESPONSE_CODE_OK: |
3490 | - fn, known_args, kwargs = query |
3491 | + fn, kwargs = query |
3492 | |
3493 | # all args will be strings - use None to indicate missing argument |
3494 | req_args = tuple(args.get(kwarg, [None])[0] for kwarg in kwargs) |
3495 | @@ -202,9 +170,8 @@ |
3496 | # if we are okay here we have a valid query with all required arguments |
3497 | if res_code is RESPONSE_CODE_OK: |
3498 | try: |
3499 | - all_args = known_args + req_args |
3500 | program = yield defer_to_thread_with_timeout(render_timeout, fn, |
3501 | - *all_args) |
3502 | + name, *req_args) |
3503 | except defer.CancelledError: |
3504 | # the timeout has fired and cancelled the request |
3505 | res_code = RESPONSE_CODE_BAD_REQUEST |
3506 | @@ -247,16 +214,12 @@ |
3507 | class GraphProperties(Resource): |
3508 | |
3509 | @staticmethod |
3510 | - def _get_connection(): |
3511 | - return db_api.SextantConnection(database_url) |
3512 | - |
3513 | - @staticmethod |
3514 | - def _get_program_names(connection): |
3515 | - return connection.get_program_names() |
3516 | - |
3517 | - @staticmethod |
3518 | - def _get_function_names(connection, program_name): |
3519 | - return connection.get_function_names(program_name) |
3520 | + def _get_program_names(): |
3521 | + return CONNECTION.get_program_names() |
3522 | + |
3523 | + @staticmethod |
3524 | + def _get_function_names(program_name): |
3525 | + return CONNECTION.get_function_names(program_name) |
3526 | |
3527 | @defer.inlineCallbacks |
3528 | def _render_GET(self, request): |
3529 | @@ -269,18 +232,9 @@ |
3530 | |
3531 | query = request.args['query'][0] |
3532 | |
3533 | - try: |
3534 | - neo4j_connection = yield deferToThread(self._get_connection) |
3535 | - except Exception: |
3536 | - request.setResponseCode(502) # Bad Gateway |
3537 | - request.write("Could not reach Neo4j server at {}.".format(database_url)) |
3538 | - request.finish() |
3539 | - defer.returnValue(None) |
3540 | - neo4j_connection = None # just to silence the "referenced before assignment" warnings |
3541 | - |
3542 | if query == 'programs': |
3543 | request.setHeader("content-type", "application/json") |
3544 | - prognames = yield deferToThread(self._get_program_names, neo4j_connection) |
3545 | + prognames = yield deferToThread(self._get_program_names) |
3546 | request.write(json.dumps(list(prognames))) |
3547 | request.finish() |
3548 | defer.returnValue(None) |
3549 | @@ -294,7 +248,7 @@ |
3550 | defer.returnValue(None) |
3551 | program_name = request.args['program_name'][0] |
3552 | |
3553 | - funcnames = yield deferToThread(self._get_function_names, neo4j_connection, program_name) |
3554 | + funcnames = yield deferToThread(self._get_function_names, program_name) |
3555 | if funcnames is None: |
3556 | request.setResponseCode(404) |
3557 | request.setHeader("content-type", "text/plain") |
3558 | @@ -319,10 +273,12 @@ |
3559 | return NOT_DONE_YET |
3560 | |
3561 | |
3562 | -def serve_site(input_database_url='http://localhost:7474', port=2905): |
3563 | - |
3564 | - global database_url |
3565 | - database_url = input_database_url |
3566 | +def serve_site(connection, port): |
3567 | + global CONNECTION |
3568 | + |
3569 | + CONNECTION = connection |
3570 | + |
3571 | + |
3572 | # serve static directory at root |
3573 | root = File(os.path.join(environment.RESOURCES_DIR, 'sextant', 'web')) |
3574 |