Merge lp:~ben-hutchings/ensoft-sextant/csv-upload into lp:ensoft-sextant

Proposed by Ben Hutchings
Status: Merged
Approved by: Robert
Approved revision: 62
Merged at revision: 30
Proposed branch: lp:~ben-hutchings/ensoft-sextant/csv-upload
Merge into: lp:ensoft-sextant
Diff against target: 3573 lines (+1945/-1015)
16 files modified
src/sextant/__main__.py (+54/-232)
src/sextant/csvwriter.py (+152/-0)
src/sextant/db_api.py (+587/-302)
src/sextant/export.py (+3/-3)
src/sextant/objdump_parser.py (+302/-262)
src/sextant/query.py (+29/-31)
src/sextant/sshmanager.py (+278/-0)
src/sextant/test_all.sh (+4/-0)
src/sextant/test_csvwriter.py (+89/-0)
src/sextant/test_db_api.py (+68/-54)
src/sextant/test_parser.py (+85/-0)
src/sextant/test_resources/parser_test.c (+57/-0)
src/sextant/test_resources/parser_test.dump (+44/-0)
src/sextant/test_sshmanager.py (+72/-0)
src/sextant/update_db.py (+96/-62)
src/sextant/web/server.py (+25/-69)
To merge this branch: bzr merge lp:~ben-hutchings/ensoft-sextant/csv-upload
Reviewer Review Type Date Requested Status
Robert Approve
Review via email: mp+239356@code.launchpad.net

Commit message

Programs are now uploaded by first parsing them into CSV files and then loading those files into the database. This is _significantly_ faster for large programs.

Furthermore, the structure of program nodes in the database has changed. Previously they were unlabelled nodes with a `type` property of 'program'; they now carry the 'program' label. Because the database partitions nodes by label, this keeps programs distinct from functions. All queries in Sextant have been updated to reflect this.

The new module sshmanager handles the SSH connection to the database server.
The new module csvwriter handles writing the row-limited CSV files used for upload.

Description of the change

Programs are now uploaded by first parsing them into CSV files and then loading those files into the database. This is _significantly_ faster for large programs.

Furthermore, the structure of program nodes in the database has changed. Previously they were unlabelled nodes with a `type` property of 'program'; they now carry the 'program' label. Because the database partitions nodes by label, this keeps programs distinct from functions. All queries in Sextant have been updated to reflect this.

The new module sshmanager handles the SSH connection to the database server.
The new module csvwriter handles writing the row-limited CSV files used for upload.

To post a comment you must log in.
Revision history for this message
Robert (rjwills) :
review: Approve

Preview Diff

[H/L] Next/Prev Comment, [J/K] Next/Prev File, [N/P] Next/Prev Hunk
1=== modified file 'src/sextant/__main__.py'
2--- src/sextant/__main__.py 2014-10-03 13:00:52 +0000
3+++ src/sextant/__main__.py 2014-10-23 12:33:12 +0000
4@@ -9,7 +9,6 @@
5
6 import io
7 import sys
8-import random
9 import socket
10 import logging
11 import logging.config
12@@ -28,10 +27,12 @@
13 from . import db_api
14 from . import update_db
15 from . import environment
16+from . import sshmanager
17
18 config = environment.load_config()
19
20
21+
22 def _displayable_url(args):
23 """
24 Return the URL specified by the user for Sextant to look at.
25@@ -56,7 +57,7 @@
26
27 # Beginning of functions which handle the actual invocation of Sextant
28
29-def _start_web(args):
30+def _start_web(connection, args):
31 # Don't import at top level - makes twisted dependency semi-optional,
32 # allowing non-web functionality to work with Python 3.
33 if sys.version_info[0] == 2:
34@@ -68,12 +69,12 @@
35 logging.info("Serving site on port {}".format(args.port))
36
37 # server is .web.server, imported a couple of lines ago
38- server.serve_site(input_database_url=args.remote_neo4j, port=args.port)
39-
40-
41-def _audit(args):
42+ server.serve_site(connection, args.port)
43+
44+
45+def _audit(connection, args):
46 try:
47- audited = query.audit(args.remote_neo4j)
48+ audited = query.audit(connection)
49 except requests.exceptions.ConnectionError as e:
50 msg = 'Connection error to server {url}: {exception}'
51 logging.error(msg.format(url=_displayable_url(args), exception=e))
52@@ -87,8 +88,8 @@
53 titles = ("Name", "#Func", "Uploader", "User-ID", "Upload Date")
54 colminlens = (len(entry) for entry in titles)
55 # maximum lengths to avoid one entry from throwing the whole table
56- # date format is <YYYY:MM:DD HH:MM:SS.UUUUUU> = 26 characters
57- COLMAXLENS = (25, 5, 25, 10, 26)
58+ # date format is <YYYY-MM-DD HH:MM:SS> = 19 characters
59+ COLMAXLENS = (25, 6, 25, 10, 19)
60
61 # make a table of the strings of each data entry we will display
62 text = [map(str, (p.program_name, p.number_of_funcs,
63@@ -120,7 +121,7 @@
64 print('\n'.join(st.format(*pentry) for pentry in text))
65
66
67-def _add_program(args):
68+def _add_program(connection, args):
69 try:
70 alternative_name = args.name_in_db[0]
71 except TypeError:
72@@ -131,12 +132,11 @@
73 # unsupplied
74
75 try:
76- update_db.upload_program(user_name=getpass.getuser(),
77- file_path=args.input_file,
78- db_url=args.remote_neo4j,
79- alternative_name=alternative_name,
80- not_object_file=not_object_file,
81- display_url=_displayable_url(args))
82+ update_db.upload_program(connection,
83+ getpass.getuser(),
84+ args.input_file,
85+ alternative_name,
86+ not_object_file)
87 except requests.exceptions.ConnectionError as e:
88 msg = 'Connection error to server {}: {}'
89 logging.error(msg.format(_displayable_url(args), e))
90@@ -147,41 +147,41 @@
91 logging.error('Input file {} was not found.'.format(args.input_file[0]))
92 logging.error(e)
93 logging.debug(e, exc_info=True)
94- except ValueError as e:
95+ except (ValueError, sshmanager.SSHConnectionError) as e:
96 logging.error(e)
97
98
99-def _delete_program(namespace):
100- update_db.delete_program(namespace.program_name,
101- namespace.remote_neo4j)
102-
103-
104-def _make_query(namespace):
105+def _delete_program(connection, args):
106+ update_db.delete_program(connection, args.program_name)
107+
108+
109+def _make_query(connection, args):
110 arg1 = None
111 arg2 = None
112 try:
113- arg1 = namespace.funcs[0]
114- arg2 = namespace.funcs[1]
115+ arg1 = args.funcs[0]
116+ arg2 = args.funcs[1]
117 except TypeError:
118 pass
119 except IndexError:
120 pass
121
122 try:
123- program_name = namespace.program[0]
124+ program_name = args.program[0]
125 except TypeError:
126 program_name = None
127
128 try:
129- suppress_common = namespace.suppress_common[0]
130+ suppress_common = args.suppress_common[0]
131 except TypeError:
132 suppress_common = False
133
134- query.query(remote_neo4j=namespace.remote_neo4j,
135- display_neo4j=_displayable_url(namespace),
136- input_query=namespace.query,
137+ query.query(remote_neo4j=args.remote_neo4j,
138+ display_neo4j=_displayable_url(args),
139+ input_query=args.query,
140 program_name=program_name,
141- argument_1=arg1, argument_2=arg2,
142+ argument_1=arg1,
143+ argument_2=arg2,
144 suppress_common=suppress_common)
145
146 # End of functions which invoke Sextant
147@@ -197,8 +197,10 @@
148
149 """
150
151- argumentparser = argparse.ArgumentParser(prog='sextant', usage='sextant', description="Invoke part of the SEXTANT program")
152- subparsers = argumentparser.add_subparsers(title="subcommands")
153+ ap = argparse.ArgumentParser(prog='sextant',
154+ usage='sextant',
155+ description="Invoke part of the SEXTANT program")
156+ subparsers = ap.add_subparsers(title="subcommands")
157
158 #set what will be defined as a "common function"
159 db_api.set_common_cutoff(config.common_cutoff)
160@@ -257,10 +259,9 @@
161 parsers[key].add_argument('--remote-neo4j', metavar="URL",
162 help="URL of neo4j server", type=str,
163 default=config.remote_neo4j)
164- parsers[key].add_argument('--use-ssh-tunnel', metavar="BOOL", type=str,
165- help="whether to SSH into the remote server,"
166- "True/False",
167- default=str(config.use_ssh_tunnel))
168+ parsers[key].add_argument('--no-ssh-tunnel',
169+ help='Disable ssh tunnelling. Prevents program upload.',
170+ action='store_true')
171 parsers[key].add_argument('--ssh-user', metavar="NAME", type=str,
172 help="username to use as remote SSH name",
173 default=str(config.ssh_user))
174@@ -273,207 +274,28 @@
175
176 # parse the arguments
177
178- return argumentparser.parse_args()
179-
180-
181-def _start_tunnel(local_port, remote_host, remote_port, ssh_user=''):
182- """
183- Creates an SSH port-forward.
184-
185- This will result in localhost:local_port appearing to be
186- remote_host:remote_port.
187-
188- :param local_port: integer port number to open at localhost
189- :param remote_host: string address of remote host (no port number)
190- :param remote_port: port to 'open' on the remote host
191- :param ssh_user: user to log in on the remote_host as
192-
193- """
194-
195- if not (isinstance(local_port, int) and local_port > 0):
196- raise ValueError(
197- 'Local port {} must be a positive integer.'.format(local_port))
198- if not (isinstance(remote_port, int) and remote_port > 0):
199- raise ValueError(
200- 'Remote port {} must be a positive integer.'.format(remote_port))
201-
202- logging.debug('Starting SSH tunnel...')
203-
204- # this cmd string will be .format()ed in a few lines' time
205- cmd = ['ssh']
206-
207- if ssh_user:
208- # ssh -l {user} ... sets the remote login username
209- cmd += ['-l', ssh_user]
210-
211- # -L localport:localhost:remoteport forwards the port
212- # -M makes SSH able to accept slave connections
213- # -S sets the location of a control socket (in this case, sextant-controller
214- # with a unique identifier appended, just in case we run sextant twice
215- # simultaneously), so we know how to close the port again
216- # -f goes into background; -N does not execute a remote command;
217- # -T says to remote host that we don't want a text shell.
218- cmd += ['-M',
219- '-S', 'sextantcontroller{tunnel_id}'.format(tunnel_id=local_port),
220- '-fNT',
221- '-L', '{0}:localhost:{1}'.format(local_port, remote_port),
222- remote_host]
223-
224- logging.debug('Running {}'.format(' '.join(cmd)))
225-
226- exit_code = subprocess.call(cmd)
227- if exit_code:
228- raise OSError('SSH setup failed with error {}'.format(exit_code))
229-
230- logging.debug('SSH tunnel created.')
231-
232-
233-def _stop_tunnel(local_port, remote_host):
234- """
235- Tear down an SSH port-forward which was previously set up with start_tunnel.
236-
237- We use local_port as an identifier.
238- :param local_port: the port on localhost we are using as the entrypoint
239- :param remote_host: remote host we tunnelled into
240-
241- """
242-
243- logging.debug('Shutting down SSH tunnel...')
244-
245- # ssh -O sends a command to the slave specified in -S
246- cmd = ['ssh',
247- '-S', 'sextantcontroller{}'.format(local_port),
248- '-O', 'exit',
249- '-q', # for quiet
250- remote_host]
251-
252- # SSH has a bug on some systems which causes it to ignore the -q flag
253- # meaning it prints "Exit request sent." to stderr.
254- # To avoid this, we grab stderr temporarily, and see if it's that string;
255- # if it is, suppress it.
256- pr = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
257- stdout, stderr = pr.communicate()
258- if stderr.rstrip() != 'Exit request sent.':
259- print(stderr, file=sys.stderr)
260- if pr.returncode == 0:
261- logging.debug('Shut down successfully.')
262- else:
263- logging.warning(
264- 'SSH tunnel shutdown returned error code {}'.format(pr.returncode))
265- logging.warning(stderr)
266-
267-
268-def _is_port_used(port):
269- """
270- Checks with the OS to see whether a port is open.
271-
272- Beware: port is passed directly to the shell. Make sure it is an integer.
273- We raise ValueError if it is not.
274- :param port: integer port to check for openness
275- :return: bool(port is in use)
276-
277- """
278-
279- # we follow http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python
280- if not (isinstance(port, int) and port > 0):
281- raise ValueError('port {} must be a positive integer.'.format(port))
282-
283- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
284- try:
285- sock.bind(('127.0.0.1', port))
286- except socket.error as e:
287- if e.errno == 98: # Address already in use
288- return True
289- raise
290-
291- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
292-
293- return False # that is, the port is not used
294-
295-
296-def _get_unused_port():
297- """
298- Returns a port number between 10000 and 50000 which is not currently open.
299-
300- """
301-
302- keep_going = True
303- while keep_going:
304- portnum = random.randint(10000, 50000)
305- keep_going = _is_port_used(portnum)
306- return portnum
307-
308+ return ap.parse_args()
309
310 def _get_host_and_port(url):
311- """Given a URL as http://host:port, returns (host, port)."""
312- parsed = parse.urlparse(url)
313- return (parsed.hostname, parsed.port)
314-
315-
316-def _is_localhost(host, port):
317- """
318- Checks whether a host is an alias to localhost.
319-
320- Raises socket.gaierror if the host was not found.
321-
322- """
323-
324- addr = socket.getaddrinfo(host, port)[0][4][0]
325-
326- return addr in ('127.0.0.1', '::1')
327+ """Given a URL as http://host:port, returns (host, port)."""
328+ parsed = parse.urlparse(url)
329+ return (parsed.hostname, parsed.port)
330
331
332 def main():
333 args = parse_arguments()
334-
335- if args.use_ssh_tunnel.lower() == 'true':
336- localport = _get_unused_port()
337-
338- remotehost, remoteport = _get_host_and_port(args.remote_neo4j)
339-
340- try:
341- is_loc = _is_localhost(remotehost, remoteport)
342- except socket.gaierror:
343- logging.error('Server {} not found.'.format(remotehost))
344- return
345-
346- if is_loc:
347- # we are attempting to connect to localhost anyway, so we won't
348- # bother to SSH to it.
349- # There may be some ways the user can trick us into trying to SSH
350- # to localhost anyway, but this will do as a first pass.
351- # SSHing to localhost is undesirable because on my test computer,
352- # we get 'connection refused' if we try.
353- args.func(args)
354-
355- else: # we need to SSH
356- try:
357- _start_tunnel(localport, remotehost, remoteport,
358- ssh_user=args.ssh_user)
359- except OSError as e:
360- logging.error(str(e))
361- return
362- except KeyboardInterrupt:
363- logging.info('Halting because of user interrupt.')
364- return
365-
366- try:
367- args.display_neo4j = args.remote_neo4j
368- args.remote_neo4j = 'http://localhost:{}'.format(localport)
369- args.func(args)
370- except KeyboardInterrupt:
371- # this probably happened because we were running Sextant Web
372- # and Ctrl-C'ed out of it
373- logging.info('Keyboard interrupt detected. Halting.')
374- pass
375-
376- finally:
377- _stop_tunnel(localport, remotehost)
378-
379- else: # no need to set up the ssh, just run sextant
380- args.func(args)
381-
382-
383+ remotehost, remoteport = _get_host_and_port(args.remote_neo4j)
384+ no_ssh_tunnel = args.no_ssh_tunnel
385+ connection = None
386+
387+ try:
388+ conn_args = (remotehost, remoteport, no_ssh_tunnel)
389+ with db_api.SextantConnection(*conn_args) as connection:
390+ args.func(connection, args)
391+ except sshmanager.SSHConnectionError as e:
392+ print(e.message)
393+
394+
395 if __name__ == '__main__':
396 main()
397
398
399=== added file 'src/sextant/csvwriter.py'
400--- src/sextant/csvwriter.py 1970-01-01 00:00:00 +0000
401+++ src/sextant/csvwriter.py 2014-10-23 12:33:12 +0000
402@@ -0,0 +1,152 @@
403+import logging
404+
405+"""
406+Provide a class for writing to row-limited csv files.
407+"""
408+__all__ = ('CSVWriter',)
409+
410+
411+class CSVWriter(object):
412+ """
413+ Write to csv files, automatically opening new ones at row maximum.
414+
415+ Provides a write(*args) method which will add a row to the currently open
416+ csv file (internally managed) if there is room in it, otherwise close it,
417+ silently open a new one and write to that.
418+
419+ Attributes:
420+ base_path:
421+ The base path of the output files - which will have a full path
422+ of form "<base_path><number>.csv"
423+ headers:
424+ A list or tuple of strings which will be used as the column
425+ headers. Attempts to write a row of data will induce a check
426+ that the length of the data provided is exactly that of this
427+ argument.
428+ max_rows:
429+ The maximum number of rows to write in each file (including the
430+ header row) before opening a new file.
431+
432+ _fmt:
433+ The format string which will be used to write a row to the csv
434+ file. Of form '{},{},...,{}\n'.
435+ _file:
436+ The currently open file.
437+ _file_count:
438+ The number of files that the CSVWriter has written to. The next
439+ file to be opened will have name '<base_path><_file_count>.csv'
440+ _row_count:
441+ The number of rows (including the header row) in the current file.
442+ _total_row_count:
443+ The number of rows (including the header rows) in ALL files.
444+
445+ """
446+ # Filename fmt of output files - used with .format(base_path, number).
447+ _file_fmt = '{}{}.csv'
448+
449+ def __init__(self, base_path, headers, max_rows):
450+ """
451+ Initialise the writer for writing.
452+
453+ Arguments:
454+ base_path:
455+ The base path of the output files - which will have a full path
456+ of form "<base_path><number>.csv"
457+ headers:
458+ A list or tuple of strings which will be used as the column
459+ headers. Attempts to write a row of data will induce a check
460+ that the length of the data provided is exactly that of this
461+ argument.
462+ max_rows:
463+ The maximum number of rows to write in each file (including the
464+ header row) before opening a new file.
465+ """
466+ self.base_path = base_path
467+ self.headers = headers
468+ self.max_rows = max_rows
469+
470+ self._fmt = ','.join('{}' for h in headers) + '\n'
471+
472+ # The number of the file we are on and the line in it.
473+ self._file = None
474+ self._file_count = 0
475+ self._row_count = 0
476+
477+ self._total_row_count = 0
478+
479+ self._open_new_file()
480+
481+ def _open_new_file(self):
482+ """
483+ Open a new file for editing, writing the headers in the first row.
484+ """
485+ self._close_file()
486+
487+ path = CSVWriter._file_fmt.format(self.base_path, self._file_count)
488+ self._file = open(path, 'w+')
489+ self._file_count += 1
490+ self.write(*self.headers)
491+
492+ def _close_file(self):
493+ """
494+ Close the current file.
495+
496+ NOTE that this method should ALWAYS be called before attempting to read
497+ from the file as it ensures that all changes have been written to disk,
498+ not only buffered.
499+ """
500+ if self._file and not self._file.closed:
501+ logging.debug('csvwriter wrote {} lines to {}'
502+ .format(self._row_count, self._file.name))
503+ self._file.close()
504+
505+ self._row_count = 0
506+
507+ def write(self, *args):
508+ """
509+ Add a row the to current file, or to a new one if max_rows is reached.
510+
511+ The check against max_rows is made BEFORE writing the line.
512+
513+ Raises:
514+ ValueError:
515+ If the length of *args is not exactly the length of
516+ self.headers - i.e. on attempt to write too many/too few items.
517+
518+ Arguments:
519+ *args:
520+ Strings, which will be written into the columns of the current
521+ open csv file.
522+ """
523+ if not len(args) == len(self.headers):
524+ msg = 'Attempted to write {} entries to file {} with {} columns'
525+ raise ValueError(msg.format(len(args), self.base_path,
526+ len(self.headers)))
527+
528+ if self._row_count == self.max_rows:
529+ self._close_file()
530+ self._open_new_file()
531+
532+ self._file.write(self._fmt.format(*args))
533+ self._row_count += 1
534+ self._total_row_count += 1
535+
536+ def file_iter(self):
537+ """
538+ Return an iterator over the names of the files the writer has
539+ written to.
540+ """
541+ fmt = CSVWriter._file_fmt
542+ return (fmt.format(self.base_path, i) for i in range(self._file_count))
543+
544+ def finish(self):
545+ """
546+ Flush and close the current file. If a subsequent call to self.write
547+ is made, a new file will be created to contain it.
548+
549+ Return the number of files we have written to and the total number
550+ of lines we have written.
551+ """
552+ self._close_file()
553+ return self._file_count, self._total_row_count
554+
555
556=== modified file 'src/sextant/db_api.py'
557--- src/sextant/db_api.py 2014-09-03 14:10:07 +0000
558+++ src/sextant/db_api.py 2014-10-23 12:33:12 +0000
559@@ -5,208 +5,348 @@
560 # -----------------------------------------
561 # API to interact with a Neo4J server: upload, query and delete programs in a DB
562
563-__all__ = ("Validator", "AddToDatabase", "FunctionQueryResult", "Function",
564+from __future__ import print_function
565+
566+__all__ = ("validate_query", "DBProgram", "FunctionQueryResult", "Function",
567 "SextantConnection")
568
569+from sys import stdout
570+
571 import re # for validation of function/program names
572 import logging
573 from datetime import datetime
574 import os
575 import getpass
576 from collections import namedtuple
577-
578-from neo4jrestclient.client import GraphDatabase
579-import neo4jrestclient.client as client
580-
581+import random
582+import socket
583+
584+import itertools
585+import subprocess
586+from time import time
587+
588+import neo4jrestclient.client as neo4jrestclient
589+
590+from sshmanager import SSHManager, SSHConnectionError
591+from csvwriter import CSVWriter
592+
593+# The directory on the local machine to which csv files will be written
594+# prior to copy over to the remote server.
595+TMP_DIR = '/tmp/sextant'
596+
597+# A function is deemed 'common' if it has more than this
598+# many connections.
599 COMMON_CUTOFF = 10
600-# a function is deemed 'common' if it has more than this
601-# many connections
602-
603-
604-class Validator():
605- """ Sanitises/checks strings, to prevent Cypher injection attacks"""
606-
607- @staticmethod
608- def validate(input_):
609- """
610- Checks whether we can allow a string to be passed into a Cypher query.
611- :param input_: the string we wish to validate
612- :return: bool(the string is allowed)
613- """
614- regex = re.compile(r'^[A-Za-z0-9\-:\.\$_@\*\(\)%\+,]+$')
615- return bool(regex.match(input_))
616-
617- @staticmethod
618- def sanitise(input_):
619- """
620- Strips harmful characters from the given string.
621- :param input_: string to sanitise
622- :return: the sanitised string
623- """
624- return re.sub(r'[^\.\-_a-zA-Z0-9]+', '', input_)
625-
626-
627-class AddToDatabase():
628- """Updates the database, adding functions/calls to a given program"""
629-
630- def __init__(self, program_name='', sextant_connection=None,
631- uploader='', uploader_id='', date=None):
632- """
633- Object which can be used to add functions and calls to a new program
634- :param program_name: the name of the new program to be created
635- (must already be validated against Validator)
636- :param sextant_connection: the SextantConnection to use for connections
637- :param uploader: string identifier of user who is uploading
638- :param uploader_id: string Unix user-id of logged-in user
639- :param date: string date of today
640- """
641- # program_name must be alphanumeric, to avoid injection attacks easily
642- if not Validator.validate(program_name):
643- return
644+
645+
646+
647+def set_common_cutoff(common_def):
648+ """
649+ Sets the number of incoming connections at which we deem a function 'common'
650+ Default is 10 (which is used if this method is never called).
651+ :param common_def: number of incoming connections
652+ """
653+ global COMMON_CUTOFF
654+ COMMON_CUTOFF = common_def
655+
656+
657+def validate_query(string):
658+ """
659+ Checks whether we can allow a string to be passed into a Cypher query.
660+ :param string: the string we wish to validate
661+ :return: bool(the string is allowed)
662+ """
663+ regex = re.compile(r'^[A-Za-z0-9\-:\.\$_@\*\(\)%\+,]+$')
664+ return bool(regex.match(string))
665+
666+
667+class DBProgram(object):
668+ """
669+ Representation of a program in the database.
670+
671+ Provides add_function and add_call methods which locally register functions
672+ and calls. The commit method uploads everything to the database.
673+
674+ Attributes:
675+ uploader, uploader_id, program_name, date:
676+ As in __init__.
677+
678+ _conn:
679+ The SextantConnection object managing the database connection.
680+ _ssh:
681+ The SSHManager object belonging to the SextantConnection.
682+ _db:
683+ The database object belonging to the SextantConnection.
684+
685+ _tmp_dir:
686+ The user-specific location of the local temporary directory.
687+
688+ func_writer:
689+ A CSVWriter object which manages the csv files containing the
690+ list of functions in the program.
691+ call_writer:
692+ A CSVWriter object which manages the csv files containing the
693+ list of function calls in the program.
694+
695+ add_func_query:
696+ A string for the cypher query used to create functions from a csv
697+ file.
698+ add_call_query:
699+ A string for the cypher query used to create funciton calls from
700+ a csv file.
701+ add_program_query:
702+ A string for the cypher query used to create the program node.
703+ """
704+
705+ def __init__(self, connection, program_name, uploader, uploader_id, date):
706+ """
707+ Initialise the database program.
708+
709+ A local temporary folder is created at 'TMP_DIR-<user_name>'.
710+ When functions or calls are added via the add_function/call methods,
711+ they are registered in csv files which are stored in this directory.
712+
713+ Committing the program copies these files to the neo4j server and
714+ cleans the local tmp folder.
715+
716+ Raises:
717+ ValueError:
718+ If the program_name is not alphanumeric.
719+ CommandError:
720+ If the command to create the temporary directory failed.
721+
722+ Arguments:
723+ connection:
724+ The SextantConnection object which manages the connection to
725+ the database.
726+ program_name:
727+ The name to register the program under in the database. Must be
728+ alphanumeric.
729+ uploader:
730+ The name of the user who uploaded the program.
731+ uploader_id:
732+ A numeric id of the user who uploaded the program.
733+ date:
734+ A string representing the upload date.
735+ """
736+ # Ensure an alphanumeric program name.
737+ if not validate_query(program_name):
738+ raise ValueError('program name must be alphanumeric, got: {}'
739+ .format(program_name));
740+
741+ self.uploader = uploader
742+ self.uploader_id = uploader_id
743
744 self.program_name = program_name
745- self.parent_database_connection = sextant_connection
746- self._functions = {}
747- self._funcs_tx = None # transaction for uploading functions
748- self._calls_tx = None # transaction for uploading relationships
749-
750- if self.parent_database_connection:
751- # we'll locally use db for short
752- db = self.parent_database_connection._db
753-
754- parent_function = db.nodes.create(name=program_name,
755- type='program',
756- uploader=uploader,
757- uploader_id=uploader_id,
758- date=date)
759- self._parent_id = parent_function.id
760-
761- self._funcs_tx = db.transaction(using_globals=False, for_query=True)
762- self._calls_tx = db.transaction(using_globals=False, for_query=True)
763-
764- self._connections = []
765-
766- @staticmethod
767- def _get_display_name(function_name):
768- """
769- Gets the name we will display to the user for this function name.
770-
771- For instance, if function_name were __libc_start_main@plt, we would
772- return ("__libc_start_main", "plt_stub"). The returned function type is
773- currently one of "plt_stub", "function_pointer" or "normal".
774-
775- :param function_name: the name straight from objdump of a function
776- :return: ("display name", "function type")
777-
778- """
779-
780- if function_name[-4:] == "@plt":
781- display_name = function_name[:-4]
782- function_group = "plt_stub"
783- elif function_name[:20] == "_._function_pointer_":
784- display_name = function_name
785- function_group = "function_pointer"
786- else:
787- display_name = function_name
788- function_group = "normal"
789-
790- return display_name, function_group
791-
792- def add_function(self, function_name):
793- """
794- Adds a function to the program, ready to be sent to the remote database.
795- If the function name is already in use, this method effectively does
796- nothing and returns True.
797-
798- :param function_name: a string which must be alphanumeric
799- :return: True if the request succeeded, False otherwise
800- """
801- if not Validator.validate(function_name):
802- return False
803- if self.class_contains_function(function_name):
804- return True
805-
806- display_name, function_group = self._get_display_name(function_name)
807-
808- query = ('START n = node({}) '
809- 'CREATE (n)-[:subject]->(m:func {{type: "{}", name: "{}"}}) '
810- 'RETURN m.name, id(m)')
811- query = query.format(self._parent_id, function_group, display_name)
812-
813- self._funcs_tx.append(query)
814-
815- self._functions[function_name] = function_name
816-
817- return True
818-
819- def class_contains_function(self, function_to_find):
820- """
821- Checks whether we contain a function with a given name.
822- :param function_to_find: string name of the function we wish to look up
823- :return: bool(the function exists in this AddToDatabase)
824- """
825- return function_to_find in self._functions
826-
827- def class_contains_call(self, function_calling, function_called):
828- """
829- Checks whether we contain a call between the two named functions.
830- :param function_calling: string name of the calling-function
831- :param function_called: string name of the called function
832- :return: bool(function_calling calls function_called in us)
833- """
834- return (function_calling, function_called) in self._connections
835-
836- def add_function_call(self, fn_calling, fn_called):
837- """
838- Adds a function call to the program, ready to be sent to the database.
839- Effectively does nothing if there is already a function call between
840- these two functions.
841- Function names must be alphanumeric for easy security purposes;
842- returns False if they fail validation.
843- :param fn_calling: the name of the calling-function as a string.
844- It should already exist in the AddToDatabase; if it does not,
845- this method will create a stub for it.
846- :param fn_called: name of the function called by fn_calling.
847- If it does not exist, we create a stub representation for it.
848- :return: True if successful, False otherwise
849- """
850- if not all((Validator.validate(fn_calling),
851- Validator.validate(fn_called))):
852- return False
853-
854- if not self.class_contains_function(fn_called):
855- self.add_function(fn_called)
856- if not self.class_contains_function(fn_calling):
857- self.add_function(fn_calling)
858-
859- if not self.class_contains_call(fn_calling, fn_called):
860- self._connections.append((fn_calling, fn_called))
861-
862- return True
863+ self.date = date
864+
865+ self._conn = connection
866+ self._ssh = connection._ssh
867+ self._db = connection._db
868+
869+ self._tmp_dir = '{}-{}'.format(TMP_DIR, getpass.getuser())
870+
871+ # Make the local tmp file - csv files will be written into here.
872+ try:
873+ os.makedirs(self._tmp_dir)
874+ except OSError as e:
875+ if e.errno == os.errno.EEXIST: # File already exists.
876+ pass
877+ else:
878+ raise e
879+
880+
881+ tmp_path = os.path.join(self._tmp_dir, '{}_{{}}'.format(program_name))
882+
883+ self.func_writer = CSVWriter(tmp_path.format('funcs'),
884+ headers=['name', 'type'],
885+ max_rows=5000)
886+ self.call_writer = CSVWriter(tmp_path.format('calls'),
887+ headers=['caller', 'callee'],
888+ max_rows=5000)
889+
890+ # Define the queries we use to upload the functions and calls.
891+ self.add_func_query = (' USING PERIODIC COMMIT 250'
892+ ' LOAD CSV WITH HEADERS FROM "file:{}" AS line'
893+ ' WITH line, toInt(line.id) as lineid'
894+ ' MATCH (n:program {{name: "{}"}})'
895+ ' CREATE (n)-[:subject]->(m:func {{name: line.name,'
896+ ' id: lineid, type: line.type}})')
897+
898+ self.add_call_query = (' USING PERIODIC COMMIT 250'
899+ ' LOAD CSV WITH HEADERS FROM "file:{}" AS line'
900+ ' MATCH (p:program {{name: "{}"}})'
901+ ' MATCH (p)-[:subject]->(n:func {{name: line.caller}})'
902+ ' USING INDEX n:func(name)'
903+ ' MATCH (p)-[:subject]->(m:func {{name: line.callee}})'
904+ ' USING INDEX m:func(name)'
905+ ' CREATE (n)-[r:calls]->(m)')
906+
907+ self.add_program_query = ('CREATE (p:program {{name: "{}", uploader: "{}", '
908+ ' uploader_id: "{}", date: "{}",'
909+ ' function_count: {}, call_count: {}}})')
910+
911+
912+ def __enter__(self):
913+ """
914+ Allow DBProgram to be used as a context manager.
915+ """
916+ return self
917+
918+ def __exit__(self, etype, evalue, etrace):
919+ """
920+ Make sure that all files are properly closed.
921+ """
922+ self.func_writer.finish()
923+ self.call_writer.finish()
924+
925+ # Propagate the error if there is one.
926+ return False if etype is not None else True
927+
928+ def add_function(self, name, typ='normal'):
929+ """
930+ Add a function.
931+
932+ Arguments:
933+ name:
934+ The name of the function.
935+ typ:
936+ The type of the function, may be any string, but standard types
937+ are:
938+ normal: we have the disassembly for this function
939+ stub: we have the name but not the disassembly - usually
940+ an imported library function.
941+ pointer: we know only that the function exists, not its
942+ name or details.
943+ """
944+ self.func_writer.write(name, typ)
945+
946+ def add_call(self, caller, callee):
947+ """
948+ Add a function call.
949+
950+ Arguments:
951+ caller:
952+ The name of the function making the call.
953+ callee:
954+ The name of the function called.
955+ """
956+ self.call_writer.write(caller, callee)
957+
958+
959+ def _copy_local_to_remote_tmp_dir(self):
960+ """
961+ Copy local tmp files to the server ready for upload.
962+
963+ Return a tuple of iterators, the first over the paths on the remote
964+ machine of the function files, and the second over the paths of the
965+ call files.
966+ """
967+ print('Sending files to remote server...', end='')
968+ stdout.flush()
969+ remote_funcs = self._ssh.send_to_tmp_dir(self.func_writer.file_iter())
970+ remote_calls = self._ssh.send_to_tmp_dir(self.call_writer.file_iter())
971+ print('finished.')
972+ return remote_funcs, remote_calls
973+
974+ def _clean_tmp_files(self, remote_paths):
975+ """
976+ Delete temporary files on the local and remote machines.
977+
978+ Arguments:
979+ remote_paths:
980+ A list of the paths of the remote files.
981+ """
982+ print('Cleaning temporary files...', end='')
983+ file_paths = list(itertools.chain(self.func_writer.file_iter(),
984+ self.call_writer.file_iter()))
985+
986+ for path in file_paths:
987+ os.remove(path)
988+
989+ os.rmdir(self._tmp_dir)
990+
991+ try:
992+ # If the parent sextant temp folder is empty, remove it.
993+ os.rmdir(TMP_DIR)
994+ except:
995+ # There is other stuff in TMP_DIR (e.g. from other users), so
996+ # leave it.
997+ pass
998+
999+ self._ssh.remove_from_tmp_dir(remote_paths)
1000+
1001+ print('done.')
1002+
1003+ def _create_db_constraints(self):
1004+ """
1005+ Create indexes in the database on program and function names.
1006+
1007+ The program name index is a constraint, which will also guarantee the
1008+ uniqueness of program names.
1009+ """
1010+ # Prepare a transaction object which we use to execute cypher queries.
1011+ tx = self._db.transaction(using_globals=False, for_query=True)
1012+
1013+ tx.append('CREATE CONSTRAINT ON (p:program) ASSERT p.name IS UNIQUE')
1014+ tx.append('CREATE INDEX ON :func(name)')
1015+
1016+ # Apply the transaction.
1017+ tx.commit()
1018
1019 def commit(self):
1020 """
1021- Call this when you are finished with the object.
1022- Changes are not synced to the remote database until this is called.
1023+ Insert the program into the database.
1024+
1025+ Move the local temp files created by our func_writer and call_writer
1026+ to the database server's temp directory. From there use cypher queries
1027+ to upload them into the database, before cleaning them up.
1028 """
1029- functions = self._funcs_tx.commit() # send off the function names
1030-
1031- # now functions is a list of QuerySequence objects, which each have a
1032- # .elements property which produces [['name', id]]
1033-
1034- id_funcs = dict([seq.elements[0] for seq in functions])
1035- logging.info('Functions uploaded. Uploading calls...')
1036-
1037- # so id_funcs is a dict with id_funcs['name'] == id
1038- for call in self._connections:
1039- query = ('MATCH n WHERE id(n) = {} '
1040- 'MATCH m WHERE id(m) = {} '
1041- 'CREATE (n)-[:calls]->(m)')
1042- query = query.format(id_funcs[self._get_display_name(call[0])[0]],
1043- id_funcs[self._get_display_name(call[1])[0]])
1044- self._calls_tx.append(query)
1045-
1046- self._calls_tx.commit()
1047+ # Ensure that the most recent files are flushed and closed.
1048+ func_file_count, func_line_count = self.func_writer.finish()
1049+ call_file_count, call_line_count = self.call_writer.finish()
1050+
1051+ # Account for the header line at the top of each file.
1052+ func_count = func_line_count - func_file_count
1053+ call_count = call_line_count - call_file_count
1054+
1055+ # Get the remote path names as iterators, then make lists of them
1056+ # so that we can iterate over them more than once.
1057+ remote_f_iter, remote_c_iter = self._copy_local_to_remote_tmp_dir()
1058+ remote_funcs, remote_calls = map(list, (remote_f_iter, remote_c_iter))
1059+
1060+ # Create the indexes and constraints in the database.
1061+ self._create_db_constraints()
1062+
1063+
1064+ try:
1065+ tx = self._db.transaction(using_globals=False, for_query=True)
1066+
1067+ # Create the program node in the database.
1068+ tx.append(self.add_program_query.format(self.program_name, self.uploader,
1069+ self.uploader_id, self.date,
1070+ func_count, call_count))
1071+ tx.commit()
1072+
1073+ # Create the functions, then the calls.
1074+ for files, query, descr in zip((remote_funcs, remote_calls),
1075+ (self.add_func_query, self.add_call_query),
1076+ ('funcs', 'calls')):
1077+ start = time()
1078+ for i, path in enumerate(files):
1079+ completed = int(100*float(i+1)/len(files))
1080+
1081+ print('\rUploading {}: {}%'.format(descr, completed), end='')
1082+ stdout.flush()
1083+
1084+ tx.append(query.format(path, self.program_name))
1085+ tx.commit()
1086+ end = time()
1087+ print(' done.')
1088+
1089+ finally:
1090+ # Clean up temporary files and folders.
1091+ self._clean_tmp_files(remote_funcs + remote_calls)
1092
1093
1094 class FunctionQueryResult:
1095@@ -219,7 +359,7 @@
1096 self._update_common_functions()
1097
1098 def __eq__(self, other):
1099- # we make a dictionary so that we can perform easy comparison
1100+ # We make a dictionary so that we can perform easy comparison.
1101 selfdict = {func.name: func for func in self.functions}
1102 otherdict = {func.name: func for func in other.functions}
1103
1104@@ -243,20 +383,20 @@
1105 if rest_output is None or not rest_output.elements:
1106 return []
1107
1108- # how we store this is: a dict
1109+ # How we store this is: a dict
1110 # with keys 'functionname'
1111 # and values [the function object we will use,
1112 # and a set of (function names this function calls),
1113- # and numeric ID of this node in the Neo4J database]
1114+ # and numeric ID of this node in the Neo4J database].
1115
1116 result = {}
1117
1118- # initial pass for names of functions
1119+ # Initial pass for names of functions.
1120
1121- # if the following assertion failed, we've probably called db.query
1122+ # If the following assertion failed, we've probably called db.query
1123 # to get it to not return client.Node objects, which is wrong.
1124 # we attempt to handle this a bit later; this should never arise, but
1125- # we can cope with it happening in some cases, like the test suite
1126+ # we can cope with it happening in some cases, like the test suite.
1127
1128 if type(rest_output.elements) is not list:
1129 logging.warning('Not a list: {}'.format(type(rest_output.elements)))
1130@@ -264,11 +404,12 @@
1131 for node_list in rest_output.elements:
1132 assert(isinstance(node_list, list))
1133 for node in node_list:
1134- if isinstance(node, client.Node):
1135+ if isinstance(node, neo4jrestclient.Node):
1136 name = node.properties['name']
1137 node_id = node.id
1138 node_type = node.properties['type']
1139- else: # this is the handling we mentioned earlier;
1140+ else:
1141+ # This is the handling we mentioned earlier;
1142 # we are a dictionary instead of a list, as for some
1143 # reason we've returned Raw rather than Node data.
1144 # We should never reach this code, but just in case.
1145@@ -283,7 +424,7 @@
1146 set(),
1147 node_id]
1148
1149- # end initialisation of names-dictionary
1150+ # End initialisation of names-dictionary.
1151
1152 if self._parent_db_connection is not None:
1153 # This is the normal case, of extracting results from a server.
1154@@ -301,7 +442,7 @@
1155 logging.debug('exec')
1156 results = new_tx.execute()
1157
1158- # results is a list of query results, each of those being a list of
1159+ # Results is a list of query results, each of those being a list of
1160 # calls.
1161
1162 for call_list in results:
1163@@ -315,7 +456,7 @@
1164 # recall: set union is denoted by |
1165
1166 else:
1167- # we don't have a parent database connection.
1168+ # We don't have a parent database connection.
1169 # This has probably arisen because we created this object from a
1170 # test suite, or something like that.
1171 for node in rest_output.elements:
1172@@ -353,19 +494,10 @@
1173 func_list = [func for func in self.functions if func.name == name]
1174 return None if len(func_list) == 0 else func_list[0]
1175
1176-
1177-def set_common_cutoff(common_def):
1178- """
1179- Sets the number of incoming connections at which we deem a function 'common'
1180- Default is 10 (which is used if this method is never called).
1181- :param common_def: number of incoming connections
1182- """
1183- global COMMON_CUTOFF
1184- COMMON_CUTOFF = common_def
1185-
1186-
1187 class Function(object):
1188- """Represents a function which might appear in a FunctionQueryResult."""
1189+ """
1190+ Represents a function which might appear in a FunctionQueryResult.
1191+ """
1192
1193 def __eq__(self, other):
1194 funcs_i_call_list = {func.name for func in self.functions_i_call}
1195@@ -393,11 +525,11 @@
1196 self.name = function_name
1197 self.is_common = False
1198 self._number_calling_me = 0
1199- # care: _number_calling_me is not automatically updated, except by
1200+ # Care: _number_calling_me is not automatically updated, except by
1201 # any invocation of FunctionQueryResult._update_common_functions.
1202
1203
1204-class SextantConnection:
1205+class SextantConnection(object):
1206 """
1207 RESTful connection to a remote database.
1208 It can be used to create/delete/query programs.
1209@@ -406,56 +538,214 @@
1210 ProgramWithMetadata = namedtuple('ProgramWithMetadata',
1211 ['uploader', 'uploader_id',
1212 'program_name', 'date',
1213- 'number_of_funcs'])
1214-
1215- def __init__(self, url):
1216- self.url = url
1217- self._db = GraphDatabase(url)
1218-
1219- def new_program(self, name_of_program):
1220+ 'number_of_funcs', 'number_of_calls'])
1221+
1222+ @staticmethod
1223+ def _is_localhost(host, port):
1224+ """
1225+ Checks whether a host is an alias to localhost.
1226+
1227+ Raises socket.gaierror if the host was not found.
1228+ """
1229+ addr = socket.getaddrinfo(host, port)[0][4][0]
1230+ return addr in ('127.0.0.1', '::1')
1231+
1232+ @staticmethod
1233+ def _is_port_used(port):
1234+ """
1235+ Checks with the OS to see whether a port is open.
1236+
1237+ Beware: port is passed directly to the shell. Make sure it is an integer.
1238+ We raise ValueError if it is not.
1239+ :param port: integer port to check for openness
1240+ :return: bool(port is in use)
1241+ """
1242+ result = False
1243+
1244+ # We follow:
1245+ # http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python
1246+ if not (isinstance(port, int) and port > 0):
1247+ raise ValueError('port {} must be a positive integer.'.format(port))
1248+
1249+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1250+ try:
1251+ sock.bind(('127.0.0.1', port))
1252+ except socket.error as e:
1253+ if e.errno == os.errno.EADDRINUSE:
1254+ result = True
1255+ else:
1256+ raise
1257+
1258+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
1259+
1260+ return result # that is, the port is not used
1261+
1262+ @staticmethod
1263+ def _get_unused_port():
1264+ """
1265+ Returns a port number between 10000 and 50000 which is not currently open.
1266+ """
1267+
1268+ keep_going = True
1269+ while keep_going:
1270+ portnum = random.randint(10000, 50000)
1271+ keep_going = SextantConnection._is_port_used(portnum)
1272+ return portnum
1273+
1274+
1275+ def __enter__(self):
1276+ return self
1277+
1278+ def __exit__(self, etype, evalue, etrace):
1279+ self.close()
1280+ return False if etype is not None else True
1281+
1282+
1283+ def __init__(self, remotehost, remoteport, no_ssh_tunnel=False):
1284+ """
1285+ Initialise the database and ssh connections.
1286+
1287+ Arguments:
1288+ remotehost:
1289+ The remote host name to connect to.
1290+ remoteport:
1291+ The port number to connect to on the remote host.
1292+ no_ssh_tunnel:
1293+ Disables the SSHManager if True. Prevents program upload.
1294+ """
1295+
1296+ self.remote_host = remotehost
1297+ self.remote_port = remoteport
1298+
1299+
1300+ self._no_ssh_tunnel = no_ssh_tunnel
1301+ self._ssh = None
1302+ self._db = None
1303+
1304+ self.open()
1305+
1306+ def open(self):
1307+ local_port = SextantConnection._get_unused_port()
1308+ is_localhost = SextantConnection._is_localhost(self.remote_host, self.remote_port)
1309+
1310+ if self._no_ssh_tunnel and not is_localhost:
1311+ raise SSHConnectionError('Cannot connect to the remote database '
1312+ 'without an ssh connection.')
1313+ else:
1314+ # Either we are making an ssh tunnel or we are contacting localhost.
1315+ self._ssh = SSHManager(local_port,
1316+ self.remote_host,
1317+ self.remote_port,
1318+ is_localhost=is_localhost)
1319+
1320+ port = self.remote_port if is_localhost else local_port
1321+ url = 'http://localhost:{}'.format(port)
1322+
1323+ self._db = neo4jrestclient.GraphDatabase(url)
1324+
1325+ def close(self):
1326+ """
1327+ Close the ssh connection to clean up its resources.
1328+ """
1329+ if self._ssh:
1330+ self._ssh.close()
1331+
1332+ def new_program(self, program_name):
1333 """
1334 Request that the remote database create a new program with the given name.
1335 This procedure will create a new program remotely; you can manipulate
1336- that program using the returned AddToDatabase object.
1337+ that program using the returned DBProgram object.
1338 The name can appear in the database already, but this is not recommended
1339 because then delete_program will not know which to delete. Check first
1340 using self.check_program_exists.
1341- The name specified must pass Validator.validate()ion; this is a measure
1342+ The name specified must pass validate_query()ion; this is a measure
1343 to prevent Cypher injection attacks.
1344- :param name_of_program: string program name
1345- :return: AddToDatabase instance if successful
1346+ :param program_name: string program name
1347+ :return: DBProgram instance if successful
1348 """
1349
1350- if not Validator.validate(name_of_program):
1351- raise ValueError(
1352- "{} is not a valid program name".format(name_of_program))
1353+ if not validate_query(program_name):
1354+ raise ValueError("{} is not a valid program name"
1355+ .format(program_name))
1356
1357 uploader = getpass.getuser()
1358 uploader_id = os.getuid()
1359-
1360- return AddToDatabase(sextant_connection=self,
1361- program_name=name_of_program,
1362- uploader=uploader, uploader_id=uploader_id,
1363- date=str(datetime.now()))
1364-
1365- def delete_program(self, name_of_program):
1366+ timestr = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
1367+
1368+ return DBProgram(self, program_name, uploader,
1369+ uploader_id, date=timestr)
1370+
1371+ def delete_program(self, program_name):
1372 """
1373 Request that the remote database delete a specified program.
1374- :param name_of_program: a string which must be alphanumeric only
1375+ :param program_name: a string which must be alphanumeric only
1376 :return: bool(request succeeded)
1377 """
1378- if not Validator.validate(name_of_program):
1379- return False
1380-
1381- q = """MATCH (n) WHERE n.name= "{}" AND n.type="program"
1382- OPTIONAL MATCH (n)-[r]-(b) OPTIONAL MATCH (b)-[rel]-()
1383- DELETE b,rel DELETE n, r""".format(name_of_program)
1384-
1385- self._db.query(q)
1386+ if not program_name in self.get_program_names():
1387+ print('No program `{}` in the database'.format(program_name))
1388+ return True
1389+ else:
1390+ print('Deleting `{}` from the database. '
1391+ 'This may take some time for larger programs.'
1392+ .format(program_name))
1393+
1394+ start = time()
1395+ tx = self._db.transaction(using_globals=False, for_query=True)
1396+
1397+ count_query = (' MATCH (p:program {{name: "{}"}})'
1398+ ' RETURN p.function_count, p.call_count'
1399+ .format(program_name))
1400+
1401+ tx.append(count_query)
1402+ func_count, call_count = tx.commit()[0].elements[0]
1403+
1404+ del_call_query = ('OPTIONAL MATCH (p:program {{name: "{}"}})'
1405+ '-[:subject]->(f:func)-[c:calls]->()'
1406+ ' WITH c LIMIT 5000 DELETE c RETURN count(distinct(c))'
1407+ .format(program_name))
1408+
1409+ del_func_query = ('OPTIONAL MATCH (p:program {{name: "{}"}})'
1410+ '-[s:subject]->(f:func)'
1411+ ' WITH s, f LIMIT 5000 DELETE s, f RETURN count(f)'
1412+ .format(program_name))
1413+
1414+ del_prog_query = ('MATCH (p:program {{name: "{}"}}) DELETE p'
1415+ .format(program_name))
1416+
1417+ # Delete calls first, a node may not be deleted until all relationships
1418+ # referencing it are deleted.
1419+ for count, query, descr in zip((call_count, func_count),
1420+ (del_call_query, del_func_query),
1421+ ('calls', 'funcs')):
1422+ # Change tracks whether the last delete did anything. We would
1423+ # like to use: while done < count: ..., but if the program has
1424+ # already been partially deleted then this will never terminate.
1425+ # Furthermore, if there are no functions or no calls, the while
1426+ # loop will be appropriately skipped.
1427+ change = count
1428+ done = 0
1429+ while change:
1430+ completed = int(100 * float(done)/count)
1431+ print('\rDeleting {}: {}%'.format(descr, completed), end='')
1432+ stdout.flush()
1433+
1434+ tx.append(query)
1435+ change = tx.commit()[0].elements[0][0]
1436+ done += change
1437+ if done:
1438+ print(' done.')
1439+
1440+ # Delete the program node.
1441+ tx.append(del_prog_query)
1442+ tx.commit()
1443+
1444+ end = time()
1445+ print('Finished in {:.2f}s.'.format(end - start))
1446
1447 return True
1448
1449- def _execute_query(self, prog_name='', query=''):
1450+
1451+ def _execute_query(self, prog_name, query):
1452 """
1453 Executes a Cypher query against the remote database.
1454 Note that this returns a FunctionQueryResult, so is unsuitable for any
1455@@ -468,7 +758,7 @@
1456 :param query: verbatim query we wish the server to execute
1457 :return: a FunctionQueryResult corresponding to the server's output
1458 """
1459- rest_output = self._db.query(query, returns=client.Node)
1460+ rest_output = self._db.query(query, returns=neo4jrestclient.Node)
1461
1462 return FunctionQueryResult(parent_db=self._db,
1463 program_name=prog_name,
1464@@ -481,12 +771,11 @@
1465 method which requires a program-name input.
1466 :return: a list of function-name strings.
1467 """
1468- q = """MATCH (n) WHERE n.type = "program" RETURN n.name"""
1469+ q = 'MATCH (n:program) RETURN n.name'
1470 program_names = self._db.query(q, returns=str).elements
1471
1472- result = [el[0] for el in program_names]
1473+ return set(el[0] for el in program_names)
1474
1475- return set(result)
1476
1477 def programs_with_metadata(self):
1478 """
1479@@ -498,27 +787,28 @@
1480
1481 """
1482
1483- q = ("MATCH (base) WHERE base.type = 'program' "
1484- "MATCH (base)-[:subject]->(n)"
1485- "RETURN base.uploader, base.uploader_id, base.name, base.date, count(n)")
1486+ q = (' MATCH (p:program)'
1487+ ' RETURN p.uploader, p.uploader_id, p.name, p.date,'
1488+ ' p.function_count, p.call_count')
1489 result = self._db.query(q)
1490 return {self.ProgramWithMetadata(*res) for res in result}
1491
1492 def check_program_exists(self, program_name):
1493 """
1494 Execute query to check whether a program with the given name exists.
1495- Returns False if the program_name fails validation against Validator.
1496+ Returns False if the program_name fails validation (i.e. is possibly
1497+ unsafe as a string in a cypher query).
1498 :return: bool(the program exists in the database).
1499 """
1500
1501- if not Validator.validate(program_name):
1502+ if not validate_query(program_name):
1503 return False
1504
1505- q = ("MATCH (base) WHERE base.name = '{}' AND base.type = 'program' "
1506- "RETURN count(base)").format(program_name)
1507+ q = ('MATCH (p:program {{name: "{}"}}) RETURN p LIMIT 1'
1508+ .format(program_name))
1509
1510- result = self._db.query(q, returns=int)
1511- return result.elements[0][0] > 0
1512+ result = self._db.query(q, returns=neo4jrestclient.Node)
1513+ return bool(result)
1514
1515 def check_function_exists(self, program_name, function_name):
1516 """
1517@@ -529,18 +819,18 @@
1518 :param function_name: string name of the function to check for existence
1519 :return: bool(names validate correctly, and function exists in program)
1520 """
1521- if not self.check_program_exists(program_name):
1522- return False
1523-
1524- if not Validator.validate(program_name):
1525- return False
1526-
1527- q = ("MATCH (base) WHERE base.name = '{}' AND base.type = 'program'"
1528- "MATCH (base)-[r:subject]->(m) WHERE m.name = '{}'"
1529- "RETURN count(m)").format(program_name, function_name)
1530-
1531- result = self._db.query(q, returns=int)
1532- return result.elements[0][0] > 0
1533+ if not validate_query(program_name):
1534+ return False
1535+
1536+ pmatch = '(:program {{name: "{}"}})'.format(program_name)
1537+ fmatch = '(f:func {{name: "{}"}})'.format(function_name)
1538+ # be explicit about index usage
1539+ q = (' MATCH {}-[:subject]->{} USING INDEX f:func(name)'
1540+ ' RETURN f LIMIT 1'.format(pmatch, fmatch))
1541+
1542+ # result will be an empty list if the function was not found
1543+ result = self._db.query(q, returns=neo4jrestclient.Node)
1544+ return bool(result)
1545
1546 def get_function_names(self, program_name):
1547 """
1548@@ -552,12 +842,11 @@
1549 a set of function-name strings otherwise.
1550 """
1551
1552- if not self.check_program_exists(program_name):
1553- return None
1554+ if not validate_query(program_name):
1555+ return set()
1556
1557- q = ("MATCH (base) WHERE base.name = '{}' AND base.type = 'program' "
1558- "MATCH (base)-[r:subject]->(m) "
1559- "RETURN m.name").format(program_name)
1560+ q = (' MATCH (:program {{name: "{}"}})-[:subject]->(f:func)'
1561+ ' RETURN f.name').format(program_name)
1562 return {func[0] for func in self._db.query(q)}
1563
1564 def get_all_functions_called(self, program_name, function_calling):
1565@@ -570,16 +859,13 @@
1566 :return: FunctionQueryResult, maximal subgraph rooted at function_calling
1567 """
1568
1569- if not self.check_program_exists(program_name):
1570- return None
1571-
1572 if not self.check_function_exists(program_name, function_calling):
1573 return None
1574
1575- q = """MATCH (base) WHERE base.name = '{}' ANd base.type = 'program'
1576- MATCH (base)-[:subject]->(m) WHERE m.name='{}'
1577- MATCH (m)-[:calls*]->(n)
1578- RETURN distinct n, m""".format(program_name, function_calling)
1579+ q = (' MATCH (p:program {{name: "{}"}})-[:subject]->(f:func {{name: "{}"}})'
1580+ ' USING INDEX f:func(name)'
1581+ ' MATCH (f)-[:calls*]->(g) RETURN distinct f, g'
1582+ .format(program_name, function_calling))
1583
1584 return self._execute_query(program_name, q)
1585
1586@@ -593,16 +879,13 @@
1587 :return: FunctionQueryResult, maximal connected subgraph with leaf function_called
1588 """
1589
1590- if not self.check_program_exists(program_name):
1591- return None
1592-
1593 if not self.check_function_exists(program_name, function_called):
1594 return None
1595
1596- q = """MATCH (base) WHERE base.name = '{}' AND base.type = 'program'
1597- MATCH (base)-[r:subject]->(m) WHERE m.name='{}'
1598- MATCH (n)-[:calls*]->(m) WHERE n.name <> '{}'
1599- RETURN distinct n , m"""
1600+ q = (' MATCH (p:program {{name: "{}"}})-[:subject]->(g:func {{name: "{}"}})'
1601+ ' USING INDEX g:func(name)'
1602+ ' MATCH (f)-[:calls*]->(g) WHERE f.name <> "{}"'
1603+ ' RETURN distinct f , g')
1604 q = q.format(program_name, function_called, program_name)
1605
1606 return self._execute_query(program_name, q)
1607@@ -628,12 +911,14 @@
1608 if not self.check_function_exists(program_name, function_calling):
1609 return None
1610
1611- q = r"""MATCH (pr) WHERE pr.name = '{}' AND pr.type = 'program'
1612- MATCH p=(start {{name: "{}" }})-[:calls*]->(end {{name:"{}"}})
1613- WHERE (pr)-[:subject]->(start)
1614- WITH DISTINCT nodes(p) AS result
1615- UNWIND result AS answer
1616- RETURN answer"""
1617+ q = (' MATCH (p:program {{name: "{}"}})-[:subject]->(start:func {{name: "{}"}})'
1618+ ' USING INDEX start:func(name)'
1619+ ' MATCH (p)-[:subject]->(end:func {{name: "{}"}})'
1620+ ' USING INDEX end:func(name)'
1621+ ' MATCH path=(start)-[:calls*]->(end)'
1622+ ' WITH DISTINCT nodes(path) AS result'
1623+ ' UNWIND result AS answer'
1624+ ' RETURN answer')
1625 q = q.format(program_name, function_calling, function_called)
1626
1627 return self._execute_query(program_name, q)
1628@@ -648,11 +933,9 @@
1629 if not self.check_program_exists(program_name):
1630 return None
1631
1632- query = """MATCH (base) WHERE base.name = '{}' AND base.type = 'program'
1633- MATCH (base)-[subject:subject]->(m)
1634- RETURN DISTINCT (m)""".format(program_name)
1635-
1636- return self._execute_query(program_name, query)
1637+ q = (' MATCH (p:program {{name: "{}"}})-[:subject]->(f:func)'
1638+ ' RETURN (f)'.format(program_name))
1639+ return self._execute_query(program_name, q)
1640
1641 def get_shortest_path_between_functions(self, program_name, func1, func2):
1642 """
1643@@ -671,9 +954,11 @@
1644 if not self.check_function_exists(program_name, func2):
1645 return None
1646
1647- q = """MATCH (func1 {{ name:"{}" }}),(func2 {{ name:"{}" }}),
1648- p = shortestPath((func1)-[:calls*]->(func2))
1649- UNWIND nodes(p) AS ans
1650- RETURN ans""".format(func1, func2)
1651+ q = (' MATCH (p:program {{name: "{}"}})-[:subject]->(f:func {{name: "{}"}})'
1652+ ' USING INDEX f:func(name)'
1653+ ' MATCH (p)-[:subject]->(g:func {{name: "{}"}})'
1654+ ' MATCH path=shortestPath((f)-[:calls*]->(g))'
1655+ ' UNWIND nodes(path) AS ans'
1656+ ' RETURN ans'.format(program_name, func1, func2))
1657
1658 return self._execute_query(program_name, q)
1659
1660=== modified file 'src/sextant/export.py'
1661--- src/sextant/export.py 2014-09-04 09:46:18 +0000
1662+++ src/sextant/export.py 2014-10-23 12:33:12 +0000
1663@@ -46,7 +46,7 @@
1664 font_name = "Helvetica"
1665
1666 for func in program.get_functions():
1667- if func.type == "plt_stub":
1668+ if func.type == "stub":
1669 output_str += ' "{}" [fillcolor=pink, style=filled]\n'.format(func.name)
1670 elif func.type == "function_pointer":
1671 output_str += ' "{}" [fillcolor=yellow, style=filled]\n'.format(func.name)
1672@@ -108,7 +108,7 @@
1673
1674 for func in program.get_functions():
1675 display_func = ProgramConverter.get_display_name(func)
1676- if func.type == "plt_stub":
1677+ if func.type == "stub":
1678 colour = "#ff00ff"
1679 elif func.type == "function_pointer":
1680 colour = "#99ffff"
1681@@ -175,4 +175,4 @@
1682 output_str += '<edge source="{}" target="{}"> <data key="calls">1</data> </edge>\n'.format(func.name, callee.name)
1683
1684 output_str += '</graph>\n</graphml>'
1685- return output_str
1686\ No newline at end of file
1687+ return output_str
1688
1689=== modified file 'src/sextant/objdump_parser.py' (properties changed: -x to +x)
1690--- src/sextant/objdump_parser.py 2014-08-18 13:00:53 +0000
1691+++ src/sextant/objdump_parser.py 2014-10-23 12:33:12 +0000
1692@@ -1,273 +1,313 @@
1693-# -----------------------------------------
1694-# Sextant
1695-# Copyright 2014, Ensoft Ltd.
1696-# Author: Patrick Stevens
1697-# -----------------------------------------
1698-
1699-#!/usr/bin/python3
1700-
1701-import re
1702+#!/usr/bin/python
1703 import argparse
1704-import os.path
1705 import subprocess
1706 import logging
1707
1708-
1709-class ParsedObject():
1710- """
1711- Represents a function as parsed from an objdump disassembly.
1712- Has a name (which is the verbatim name like '__libc_start_main@plt'),
1713- a position (which is the virtual memory location in hex, like '08048320'
1714- extracted from the dump),
1715- and a canonical_position (which is the virtual memory location in hex
1716- but stripped of leading 0s, so it should be a
1717- unique id).
1718- It also has a list what_do_i_call of ParsedObjects it calls using the
1719- assembly keyword 'call'.
1720- It has a list original_code of its assembler code, too, in case it's useful.
1721- """
1722-
1723- @staticmethod
1724- def get_canonical_position(position):
1725- return position.lstrip('0')
1726-
1727- def __eq__(self, other):
1728- return self.name == other.name
1729-
1730- def __init__(self, input_lines=None, assembler_section='', function_name='',
1731- ignore_function_pointers=True, function_pointer_id=None):
1732- """
1733- Create a new ParsedObject given the definition-lines from objdump -S.
1734- A sample first definition-line is '08048300 <__gmon_start__@plt>:\n'
1735- but this method
1736- expects to see the entire definition eg
1737-
1738-080482f0 <puts@plt>:
1739- 80482f0: ff 25 00 a0 04 08 jmp *0x804a000
1740- 80482f6: 68 00 00 00 00 push $0x0
1741- 80482fb: e9 e0 ff ff ff jmp 80482e0 <_init+0x30>
1742-
1743- We also might expect assembler_section, which is for instance '.init'
1744- in 'Disassembly of section .init:'
1745- function_name is used if we want to give this function a custom name.
1746- ignore_function_pointers=True will pretend that calls to (eg) *eax do
1747- not exist; setting to False makes us create stubs for those calls.
1748- function_pointer_id is only used internally; it refers to labelling
1749- of function pointers if ignore_function_pointers is False. Each
1750- stub is given a unique numeric ID: this parameter tells init where
1751- to start counting these IDs from.
1752-
1753- """
1754- if input_lines is None:
1755- # get around Python's inability to pass in empty lists by value
1756- input_lines = []
1757-
1758- self.name = function_name or re.search(r'<.+>', input_lines[0]).group(0).strip('<>')
1759- self.what_do_i_call = []
1760- self.position = ''
1761-
1762- if input_lines:
1763- self.position = re.search(r'^[0-9a-f]+', input_lines[0]).group(0)
1764- self.canonical_position = ParsedObject.get_canonical_position(self.position)
1765- self.assembler_section = assembler_section
1766- self.original_code = input_lines[1:]
1767+"""
1768+Provide a parser class to extract functions and calls from an objdump file,
1769+and a way to generate such a file from an object file.
1770+"""
1771+__all__ = ('Parser', 'run_objdump', 'FileNotFoundError')
1772+
1773+
1774+class FileNotFoundError(Exception):
1775+ """
1776+ Exception raised when Parser fails to open its file.
1777+ """
1778+ pass
1779+
1780+
1781+class Parser(object):
1782+ """
1783+ Extract functions and calls from an object file or an objdump output file.
1784+
1785+ Only the specified sections of the disassembled code will be parsed.
1786+
1787+ Attributes:
1788+ path:
1789+ Set to file_path in __init__.
1790+ _file:
1791+ Set to file_object in __init__.
1792+ sections:
1793+ Initialised by taking the sections argument to __init__ and
1794+ and converting it to a set.
1795+ ignore_ptrs:
1796+ Set to ignore_ptrs in __init__.
1797+
1798+ section_count:
1799+ The number of sections that have been parsed.
1800+ function_count:
1801+ The number of functions that have been parsed.
1802+ call_count:
1803+ The number of function calls that have been parsed.
1804+ function_ptr_count:
1805+ The number of function pointers that have been detected.
1806+ _known_stubs:
1807+ A set of the names of functions with type 'stub' that have been
1808+ parsed - used to avoid registering a stub multiple times.
1809+
1810+ """
1811+ def __init__(self, file_path, file_object=None,
1812+ sections=None, ignore_ptrs=False,
1813+ add_function=None, add_call=None,
1814+ started=None, finished=None):
1815+ """
1816+ Initialise the parser object.
1817+
1818+ Raises:
1819+ FileNotFoundError:
1820+ If file_object was not provided and file_path couldn't be
1821+ opened.
1822+
1823+ Arguments:
1824+ file_path:
1825+ The path of the objdump output file to parse, or the path of an
1826+ object file to run objdump on and then parse.
1827+ file_object:
1828+ None if file_path is the path to an object file.
1829+ OR the file object (providing 'for line in file_object')
1830+ sections:
1831+ A list of the names of the disassembly sections to parse. An mepty
1832+ list will result in all sections being parsed.
1833+ ignore_ptrs:
1834+ If True, calls to function pointers will be ignored during parsing.
1835+ add_function:
1836+ A function to call when a function is parsed. Takes:
1837+ name: name of the parsed function
1838+ type: type of the parsed function
1839+ add_call:
1840+ A function to call when a function call is passed. Takes:
1841+ caller: name of the calling function
1842+ callee: name of the called function
1843+ started:
1844+ A function to call when the parse begins. Takes:
1845+ parser: the Parser instance which has just began parsing..
1846+ finished:
1847+ A function to call when the parse completes. Takes:
1848+ parser: the Parser instance which has just finished parsing.
1849+ e.g. if add_function/call have been set to write into files,
1850+ then finished may be set to properly flush and close them.
1851+ """
1852+ self.path = file_path
1853+ try:
1854+ self._file = file_object or self._open_file(file_path)
1855+ except FileNotFoundError:
1856+ raise
1857+
1858+ self.sections = set(sections or [])
1859+ self.ignore_ptrs = ignore_ptrs
1860+
1861+ self.section_count = 0
1862+ self.function_count = 0
1863+ self.call_count = 0
1864+ self.function_ptr_count = 0
1865+
1866+ # Avoid adding duplicate function stubs (as these are detected from
1867+ # function calls so may be repeated).
1868+ self._known_stubs = set()
1869+
1870+ # By default print information to stdout.
1871+ def print_func(name, typ):
1872+ print('func {:25}{}'.format(name, typ))
1873+
1874+ def print_call(caller, callee):
1875+ print('call {:25}{:25}'.format(caller, callee))
1876+
1877+ def print_started(parser):
1878+ print('parse started: {}[{}]'.format(self.path, ', '.join(self.sections)))
1879+
1880+
1881+ def print_finished(parser):
1882+ print('parsed {} functions and {} calls'.format(self.function_count, self.call_count))
1883+
1884+ self.add_function = add_function or print_func
1885+ self.add_call = add_call or print_call
1886+ self.started = lambda: (started or print_started)(self)
1887+ self.finished = lambda: (finished or print_finished)(self)
1888+
1889+
1890+ def _get_function_ptr_name(self):
1891+ """
1892+ Return a name for a new function pointer.
1893+ """
1894+ name = 'func_ptr_{}'.format(self.function_ptr_count)
1895+ self.function_ptr_count += 1
1896+ return name
1897+
1898+ def _add_function_normal(self, name):
1899+ """
1900+ Add a function which we have full assembly code for.
1901+ """
1902+ self.add_function(name, 'normal')
1903+ self.function_count += 1
1904+
1905+ def _add_function_ptr(self, name):
1906+ """
1907+ Add a function pointer.
1908+ """
1909+ self.add_function(name, 'pointer')
1910+ self.function_count += 1
1911+
1912+ def _add_function_stub(self, name):
1913+ """
1914+ Add a function stub - we have its name but none of its internals.
1915+ """
1916+ if not name in self._known_stubs:
1917+ self._known_stubs.add(name)
1918+ self.add_function(name, 'stub')
1919+ self.function_count += 1
1920+
1921+ def _add_call(self, caller, callee):
1922+ """
1923+ Add a function call from caller to callee.
1924+ """
1925+ self.add_call(caller, callee)
1926+ self.call_count += 1
1927+
1928+ def parse(self):
1929+ """
1930+ Parse self._file.
1931+ """
1932+ self.started()
1933+
1934+ if self._file is not None:
1935+ in_section = False # if we are in one of self.sections
1936+ current_function = None # track the caller for function calls
1937+
1938+ for line in self._file:
1939+ if line.startswith('Disassembly'):
1940+ # 'Disassembly of section <name>:\n'
1941+ section = line.split(' ')[-1].rstrip(':\n')
1942+ in_section = section in self.sections if self.sections else True
1943+ if in_section:
1944+ self.section_count += 1
1945+
1946+ elif in_section:
1947+ if line.endswith('>:\n'):
1948+ # '<address> <<function_identifier>>:\n'
1949+ # with <function_identifier> of form:
1950+ # <function_name>[@plt]
1951+ function_identifier = line.split('<')[-1].split('>')[0]
1952+
1953+ if '@' in function_identifier:
1954+ current_function = function_identifier.split('@')[0]
1955+ self._add_function_stub(current_function)
1956+ else:
1957+ current_function = function_identifier
1958+ self._add_function_normal(current_function)
1959+
1960+ elif 'call ' in line or 'callq ' in line:
1961+ # WHITESPACE to prevent picking up function names
1962+ # containing 'call'
1963+
1964+ # '<hex>: <hex> [l]call [hex] <callee_info>\n'
1965+ callee_info = line.split(' ')[-1].rstrip('\n')
1966+
1967+ # Where <callee_info> is either
1968+ # 1) '*(<register>)' call to a fn pointer
1969+ # 2) '$<hex>,$<hex>' lcall to a fn pointer
1970+ # 3) '<<function_identifier>>' call to a named function
1971+ if '<' in callee_info and '>' in callee_info:
1972+ # call to a normal or stub function
1973+ # '<function_identifier>' is of form <name>[@/-/+]<...>
1974+ # from which we extract name
1975+ callee_is_ptr = False
1976+ function_identifier = callee_info.lstrip('<').rstrip('>\n')
1977+ if '@' in function_identifier:
1978+ callee = function_identifier.split('@')[0]
1979+ self._add_function_stub(callee)
1980+ else:
1981+ callee = function_identifier.split('-')[-1].split('+')[0]
1982+ # Do not add this fn now - it is a normal func
1983+ # so we know about it from elsewhere.
1984+
1985+ else:
1986+ # Some kind of function pointer call.
1987+ callee_is_ptr = True
1988+ if not self.ignore_ptrs:
1989+ callee = self._get_function_ptr_name()
1990+ self._add_function_ptr(callee)
1991+
1992+ # Add the call.
1993+ if not (self.ignore_ptrs and callee_is_ptr):
1994+ self._add_call(current_function, callee)
1995
1996- call_regex_compiled = (ignore_function_pointers and re.compile(r'\tcall. +[^\*]+\n')) or re.compile(r'\tcall. +.+\n')
1997-
1998- lines_where_i_call = [line for line in input_lines if call_regex_compiled.search(line)]
1999-
2000- if not ignore_function_pointers and not function_pointer_id:
2001- function_pointer_id = [1]
2002-
2003- for line in lines_where_i_call:
2004- # we'll catch call and callq for the moment
2005- called = (call_regex_compiled.search(line).group(0))[8:].lstrip(' ').rstrip('\n')
2006- if called[0] == '*' and ignore_function_pointers == False:
2007- # we have a function pointer, which we'll want to give a distinct name
2008- address = '0'
2009- name = '_._function_pointer_' + str(function_pointer_id[0])
2010- function_pointer_id[0] += 1
2011-
2012- self.what_do_i_call.append((address, name))
2013-
2014- else: # we're not on a function pointer
2015- called_split = called.split(' ')
2016- if len(called_split) == 2:
2017- address, name = called_split
2018- name = name.strip('<>')
2019- # we still want to remove address offsets like +0x09 from the end of name
2020- match = re.match(r'^.+(?=\+0x[a-f0-9]+$)', name)
2021- if match is not None:
2022- name = match.group(0)
2023- self.what_do_i_call.append((address, name.strip('<>')))
2024- else: # the format of the "what do i call" is not recognised as a name/address pair
2025- self.what_do_i_call.append(tuple(called_split))
2026-
2027- def __str__(self):
2028- if self.position:
2029- return 'Memory address ' + self.position + ' with name ' + self.name + ' in section ' + str(
2030- self.assembler_section)
2031+ self.finished()
2032+
2033+ self._file.close()
2034+ result = True
2035 else:
2036- return 'Name ' + self.name
2037-
2038- def __repr__(self):
2039- out_str = 'Disassembly of section ' + self.assembler_section + ':\n\n' + self.position + ' ' + self.name + ':\n'
2040- return out_str + '\n'.join([' ' + line for line in self.original_code])
2041-
2042-
2043-class Parser:
2044- # Class to manipulate the output of objdump
2045-
2046- def __init__(self, input_file_location='', file_contents=None, sections_to_view=None, ignore_function_pointers=False):
2047- """Creates a new Parser, given an input file path. That path should be an output from objdump -D.
2048- Alternatively, supply file_contents, as a list of each line of the objdump output. We expect newlines
2049- to have been stripped from the end of each of these lines.
2050- sections_to_view makes sure we only use the specified sections (use [] for 'all sections' and None for none).
2051- """
2052- if file_contents is None:
2053- file_contents = []
2054-
2055- if sections_to_view is None:
2056- sections_to_view = []
2057-
2058- if input_file_location:
2059- file_to_read = open(input_file_location, 'r')
2060- self.source_string_list = [line for line in file_to_read]
2061- file_to_read.close()
2062- elif file_contents:
2063- self.source_string_list = [string + '\n' for string in file_contents]
2064- self.parsed_objects = []
2065- self.sections_to_view = sections_to_view
2066- self.ignore_function_pointers = ignore_function_pointers
2067- self.pointer_identifier = [1]
2068-
2069- def create_objects(self):
2070- """ Go through the source_string_list, getting object names (like 'main') along with the corresponding
2071- definitions, and put them into parsed_objects """
2072- if self.sections_to_view is None:
2073- return
2074-
2075- is_in_section = lambda name: self.sections_to_view == [] or name in self.sections_to_view
2076-
2077- parsed_objects = []
2078- current_object = []
2079- current_section = ''
2080- regex_compiled_addr_and_name = re.compile(r'[0-9a-f]+ <.+>:\n')
2081- regex_compiled_section = re.compile(r'section .+:\n')
2082-
2083- for line in self.source_string_list[4:]: # we bodge, since the file starts with a little bit of guff
2084- if regex_compiled_addr_and_name.match(line):
2085- # we are a starting line
2086- current_object = [line]
2087- elif re.match(r'Disassembly of section', line):
2088- current_section = regex_compiled_section.search(line).group(0).lstrip('section ').rstrip(':\n')
2089- current_object = []
2090- elif line == '\n':
2091- # we now need to stop parsing the current block, and store it
2092- if len(current_object) > 0 and is_in_section(current_section):
2093- parsed_objects.append(ParsedObject(input_lines=current_object, assembler_section=current_section,
2094- ignore_function_pointers=self.ignore_function_pointers,
2095- function_pointer_id=self.pointer_identifier))
2096- else:
2097- current_object.append(line)
2098-
2099- # now we should be done. We assumed that blocks begin with r'[0-9a-f]+ <.+>:\n' and end with a newline.
2100- # clear duplicates:
2101-
2102- self.parsed_objects = []
2103- for obj in parsed_objects:
2104- if obj not in self.parsed_objects: # this is so that if we jump into the function at an offset,
2105- # we still register it as being the old function, not some new function at a different address
2106- # with the same name
2107- self.parsed_objects.append(obj)
2108-
2109- # by this point, each object contains a self.what_do_i_call which is a list of tuples
2110- # ('address', 'name') if the address and name were recognised, or else (thing1, thing2, ...)
2111- # where the instruction was call thing1 thing2 thing3... .
2112-
2113- def object_lookup(self, object_name='', object_address=''):
2114- """Returns the object with name object_name or address object_address (at least one must be given).
2115- If objects with the given name or address
2116- are not found, returns None."""
2117-
2118- if object_name == '' and object_address == '':
2119- return None
2120-
2121- trial_obj = self.parsed_objects
2122-
2123- if object_name != '':
2124- trial_obj = [obj for obj in trial_obj if obj.name == object_name]
2125-
2126- if object_address != '':
2127- trial_obj = [obj for obj in trial_obj if
2128- obj.canonical_position == ParsedObject.get_canonical_position(object_address)]
2129-
2130- if len(trial_obj) == 0:
2131- return None
2132-
2133- return trial_obj
2134-
2135-def get_parsed_objects(filepath, sections_to_view, not_object_file, readable=False, ignore_function_pointers=False):
2136- if sections_to_view is None:
2137- sections_to_view = [] # because we use None for "no sections"; the intent of not providing any sections
2138- # on the command line was to look at all sections, not none
2139-
2140- # first, check whether the given file exists
2141- if not os.path.isfile(filepath):
2142- # we'd like to use FileNotFoundError, but we might be running under
2143- # Python 2, which doesn't have it.
2144- raise IOError(filepath + 'is not found.')
2145-
2146- #now the file should exist
2147- if not not_object_file: #if it is something we need to run through objdump first
2148- #we need first to run the object file through objdump
2149-
2150- objdump_file_contents = subprocess.check_output(['objdump', '-D', filepath])
2151- objdump_str = objdump_file_contents.decode('utf-8')
2152-
2153- p = Parser(file_contents=objdump_str.split('\n'), sections_to_view=sections_to_view, ignore_function_pointers=ignore_function_pointers)
2154- else:
2155+ result = False
2156+
2157+ return result
2158+
2159+ def _open_file(self, path):
2160+ """
2161+ Open and return the file at path.
2162+
2163+ Raises:
2164+ FileNotFoundError:
2165+ If the file fails to open.
2166+
2167+ Arguments:
2168+ path:
2169+ The path of the file to open.
2170+ """
2171 try:
2172- p = Parser(input_file_location=filepath, sections_to_view=sections_to_view, ignore_function_pointers=ignore_function_pointers)
2173- except UnicodeDecodeError:
2174- logging.error('File could not be parsed as a string. Did you mean to supply --object-file?')
2175- return False
2176-
2177- if readable: # if we're being called from the command line
2178- print('File read; beginning parse.')
2179- #file is now read, and we start parsing
2180-
2181- p.create_objects()
2182- return p.parsed_objects
2183+ result = open(path)
2184+ except Exception as e:
2185+ raise FileNotFoundError("parser failed to open `{}`: {}".format(path, e.strerror))
2186+
2187+ return result
2188+
2189+
2190+def run_objdump(input_file):
2191+ """
2192+ Run the objdump command on the file with the given path.
2193+
2194+ Return the input file path and a file object representing the result of
2195+ the objdump.
2196+
2197+ Arguments:
2198+ input_file:
2199+ The path of the file to run objdump on.
2200+
2201+ """
2202+ # A single section can be specified for parsing with the -j flag,
2203+ # but it is not obviously possible to parse multiple sections like this.
2204+ p = subprocess.Popen(['objdump', '-d', input_file, '--no-show-raw-insn'],
2205+ stdout=subprocess.PIPE)
2206+ g = subprocess.Popen(['egrep', 'Disassembly|call(q)? |>:$'], stdin=p.stdout, stdout=subprocess.PIPE)
2207+ return input_file, g.stdout
2208+
2209
2210 def main():
2211- argumentparser = argparse.ArgumentParser(description="Parse the output of objdump.")
2212- argumentparser.add_argument('--filepath', metavar="FILEPATH", help="path to input file", type=str, nargs=1)
2213- argumentparser.add_argument('--not-object-file', help="import text objdump output instead of the compiled file", default=False,
2214- action='store_true')
2215- argumentparser.add_argument('--sections-to-view', metavar="SECTIONS",
2216- help="sections of disassembly to view, like '.text'; leave blank for 'all'",
2217- type=str, nargs='*')
2218- argumentparser.add_argument('--ignore-function-pointers', help='whether to skip parsing calls to function pointers', action='store_true', default=False)
2219-
2220- parsed = argumentparser.parse_args()
2221+ """
2222+ Run the parser from the command line.
2223+
2224+ The path of an objdump output file, the sections to view and the ignore
2225+ function pointers flag are set with command line arguments.
2226+ """
2227+ ap = argparse.ArgumentParser(description="Parse the output of objdump.")
2228+ ap.add_argument('--filepath', metavar="FILEPATH",
2229+ help="path to input file", type=str, nargs=1, required=True)
2230+
2231+ ap.add_argument('--sections-to-view', metavar="SECTIONS",
2232+ help="disassembly sections to view, eg '.text'; leave blank for 'all'",
2233+ type=str, nargs='*')
2234+ ap.add_argument('--ignore-function-pointers',
2235+ help='skip parsing calls to function pointers',
2236+ action='store_true', default=False)
2237+
2238+ args = ap.parse_args()
2239
2240- filepath = parsed.filepath[0]
2241- sections_to_view = parsed.sections_to_view
2242- not_object_file = parsed.not_object_file
2243- readable = True
2244- function_pointers = parsed.ignore_function_pointers
2245-
2246- parsed_objs = get_parsed_objects(filepath, sections_to_view, not_object_file, readable, function_pointers)
2247- if parsed_objs is False:
2248- return 1
2249-
2250- if readable:
2251- for named_function in parsed_objs:
2252- print(named_function.name)
2253- print([f[-1] for f in named_function.what_do_i_call]) # use [-1] to get the last element, since:
2254- #either we are in ('address', 'name'), when we want the last element, or else we are in (thing1, thing2, ...)
2255- #so for the sake of argument we'll take the last thing
2256-
2257-if __name__ == "__main__":
2258+ filepath = args.filepath[0]
2259+ sections = args.sections_to_view
2260+ ignore_ptrs = args.ignore_function_pointers
2261+
2262+ parser = Parser(filepath, sections=sections, ignore_ptrs=ignore_ptrs)
2263+ parser.parse()
2264+
2265+
2266+if __name__ == '__main__':
2267 main()
2268
2269=== modified file 'src/sextant/query.py'
2270--- src/sextant/query.py 2014-08-26 16:33:20 +0000
2271+++ src/sextant/query.py 2014-10-23 12:33:12 +0000
2272@@ -14,7 +14,7 @@
2273 from .export import ProgramConverter
2274
2275
2276-def query(remote_neo4j, input_query, display_neo4j='', program_name=None,
2277+def query(connection, display_neo4j='', program_name=None,
2278 argument_1=None, argument_2=None, suppress_common=False):
2279 """
2280 Run a query against the database at remote_neo4j.
2281@@ -36,24 +36,24 @@
2282
2283 """
2284
2285- if display_neo4j:
2286- display_url = display_neo4j
2287- else:
2288- display_url = remote_neo4j
2289+ # if display_neo4j:
2290+ # display_url = display_neo4j
2291+ # else:
2292+ # display_url = remote_neo4j
2293
2294- try:
2295- db = db_api.SextantConnection(remote_neo4j)
2296- except requests.exceptions.ConnectionError as err:
2297- logging.error("Could not connect to Neo4J server {}. Are you sure it is running?".format(display_url))
2298- logging.error(str(err))
2299- return 2
2300- #Not supported in python 2
2301- #except (urllib.exceptions.MaxRetryError):
2302- # logging.error("Connection was refused to {}. Are you sure the server is running?".format(remote_neo4j))
2303- # return 2
2304- except Exception as err:
2305- logging.exception(str(err))
2306- return 2
2307+ # try:
2308+ # db = db_api.SextantConnection(remote_neo4j)
2309+ # except requests.exceptions.ConnectionError as err:
2310+ # logging.error("Could not connect to Neo4J server {}. Are you sure it is running?".format(display_url))
2311+ # logging.error(str(err))
2312+ # return 2
2313+ # #Not supported in python 2
2314+ # #except (urllib.exceptions.MaxRetryError):
2315+ # # logging.error("Connection was refused to {}. Are you sure the server is running?".format(remote_neo4j))
2316+ # # return 2
2317+ # except Exception as err:
2318+ # logging.exception(str(err))
2319+ # return 2
2320
2321 prog = None
2322 names_list = None
2323@@ -66,38 +66,38 @@
2324 if argument_1 is None:
2325 print('Supply one function name to functions-calling.')
2326 return 1
2327- prog = db.get_all_functions_calling(program_name, argument_1)
2328+ prog = connection.get_all_functions_calling(program_name, argument_1)
2329 elif input_query == 'functions-called-by':
2330 if argument_1 is None:
2331 print('Supply one function name to functions-called-by.')
2332 return 1
2333- prog = db.get_all_functions_called(program_name, argument_1)
2334+ prog = connection.get_all_functions_called(program_name, argument_1)
2335 elif input_query == 'all-call-paths':
2336 if argument_1 is None and argument_2 is None:
2337 print('Supply two function names to calls-between.')
2338 return 1
2339- prog = db.get_call_paths(program_name, argument_1, argument_2)
2340+ prog = connection.get_call_paths(program_name, argument_1, argument_2)
2341 elif input_query == 'whole-program':
2342- prog = db.get_whole_program(program_name)
2343+ prog = connection.get_whole_program(program_name)
2344 elif input_query == 'shortest-call-path':
2345 if argument_1 is None and argument_2 is None:
2346 print('Supply two function names to shortest-path.')
2347 return 1
2348- prog = db.get_shortest_path_between_functions(program_name, argument_1, argument_2)
2349+ prog = connection.get_shortest_path_between_functions(program_name, argument_1, argument_2)
2350 elif input_query == 'functions':
2351 if program_name is not None:
2352- func_names = db.get_function_names(program_name)
2353+ func_names = connection.get_function_names(program_name)
2354 if func_names:
2355 names_list = list(func_names)
2356 else:
2357 print('No functions were found in program %s on server %s.' % (program_name, display_url))
2358 else:
2359- list_of_programs = db.get_program_names()
2360+ list_of_programs = connection.get_program_names()
2361 if not list_of_programs:
2362 print('Server %s database empty.' % (display_url))
2363 return 0
2364
2365- func_list = [db.get_function_names(prog_name)
2366+ func_list = [connection.get_function_names(prog_name)
2367 for prog_name in list_of_programs]
2368
2369 if not func_list:
2370@@ -105,7 +105,7 @@
2371 else:
2372 names_list = func_list
2373 elif input_query == 'programs':
2374- list_found = list(db.get_program_names())
2375+ list_found = list(connection.get_program_names())
2376 if not list_found:
2377 print('No programs were found on server {}.'.format(display_url))
2378 else:
2379@@ -122,7 +122,5 @@
2380 print('Nothing was returned from the query.')
2381
2382
2383-def audit(remote_neo4j):
2384- db = db_api.SextantConnection(remote_neo4j)
2385-
2386- return db.programs_with_metadata()
2387+def audit(connection):
2388+ return connection.programs_with_metadata()
2389
2390=== added file 'src/sextant/sshmanager.py'
2391--- src/sextant/sshmanager.py 1970-01-01 00:00:00 +0000
2392+++ src/sextant/sshmanager.py 2014-10-23 12:33:12 +0000
2393@@ -0,0 +1,278 @@
2394+"""Provide a class to manage an SSH tunnel and controller."""
2395+import os
2396+import getpass
2397+import logging
2398+import subprocess
2399+
2400+__all__ = ('SSHConnectionError', 'SSHCommandError', 'SSHManager')
2401+
2402+# Base path of the temporary directory to create on the REMOTE machine (the
2403+# user name is appended). Files are scp'd there prior to neo4j upload.
2404+TMP_DIR = '/tmp/sextant'
2405+
2406+
2407+class SSHConnectionError(Exception):
2408+ """
2409+ An exception raised when an attempt to establish an ssh connection fails.
2410+ """
2411+ pass
2412+
2413+
2414+class SSHCommandError(Exception):
2415+ """
2416+ An exception raised when an attempt to run a command over ssh fails.
2417+ """
2418+ pass
2419+
2420+
2421+class SSHManager(object):
2422+ """
2423+ Manage an ssh tunnel with port forwarding.
2424+
2425+ Attributes:
2426+ local_port:
2427+ The port number on the local machine to forward.
2428+ remote_host:
2429+ The host to ssh into.
2430+ remote_port:
2431+ The port number on the remote host to connect to.
2432+ ssh_user:
2433+ The username to use for sshing - defaults to None, in which case
2434+ the ssh connection uses the username of the user who ran sextant.
2435+
2436+ _tmp_dir:
2437+ The per-user temporary directory on the remote machine.
2438+ _controller_name:
2439+ The name of the ssh control socket - 'sextantcontroller' with the
2440+ local port appended, so simultaneous runs do not clash.
2441+ _is_localhost:
2442+ True if we are trying to ssh into localhost. In this case do not
2443+ open the tunnel, just provide the right api so the rest of Sextant
2444+ need not special case.
2445+ """
2446+
2447+ def __init__(self, local_port, remote_host, remote_port,
2448+ ssh_user=None, is_localhost=False):
2449+ """
2450+ Open an SSH tunnel with multiplexing enabled.
2451+
2452+ Raises:
2453+ ValueError:
2454+ If local_port or remote_port are not positive integers.
2455+ SSHConnectionError:
2456+ If the SSH tunnel could not be opened.
2457+
2458+ Arguments:
2459+ local_port:
2460+ The number of the local port to forward.
2461+ remote_host:
2462+ The name of the remote host to connect to.
2463+ remote_port:
2464+ The port number on the remote host to connect to.
2465+ ssh_user:
2466+ An alternative user name to use for the ssh login.
2467+ is_localhost:
2468+ True if we are trying to ssh into localhost.
2469+ """
2470+ if not (isinstance(local_port, int) and local_port > 0):
2471+ raise ValueError(
2472+ 'Local port {} must be a positive integer.'.format(local_port))
2473+ if not (isinstance(remote_port, int) and remote_port > 0):
2474+ raise ValueError(
2475+ 'Remote port {} must be a positive integer.'.format(remote_port))
2476+
2477+ self.local_port = local_port
2478+ self.remote_host = remote_host
2479+ self.remote_port = remote_port
2480+ self.ssh_user = ssh_user
2481+
2482+ self._tmp_dir = '{}-{}'.format(TMP_DIR, self.ssh_user or getpass.getuser())
2483+
2484+ self._controller_name = 'sextantcontroller{}'.format(local_port)
2485+ self._is_localhost = is_localhost
2486+
2487+ self._open()
2488+
2489+ def _open(self):
2490+ """
2491+ Helper function to open the SSH tunnel.
2492+
2493+ Raises:
2494+ SSHConnectionError:
2495+ If the ssh command failed to run.
2496+ """
2497+ if self._is_localhost:
2498+ return
2499+
2500+ # Build the ssh command up as a list of arguments.
2501+ cmd = ['ssh']
2502+
2503+ if self.ssh_user:
2504+ # ssh -l {user} ... sets the remote login username
2505+ cmd.extend(['-l', self.ssh_user])
2506+
2507+ # -L localport:localhost:remoteport forwards the port.
2508+ port_fwd = '{}:localhost:{}'.format(self.local_port, self.remote_port)
2509+
2510+ # -M makes SSH able to accept slave connections.
2511+ # -S sets the location of a control socket (in this case, sextantcontroller
2512+ # with a unique identifier appended, just in case we run sextant twice
2513+ # simultaneously), so we know how to close the port again.
2514+ # -f goes into background; -N does not execute a remote command;
2515+ # -T says to remote host that we don't want a text shell.
2516+ cmd.extend(['-M', '-S', self._controller_name, '-fNT',
2517+ '-L', port_fwd, self.remote_host])
2518+
2519+ logging.debug('Opening SSH tunnel with cmd: {}'.format(' '.join(cmd)))
2520+
2521+ rc = subprocess.call(cmd)
2522+ if rc:
2523+ raise SSHConnectionError('SSH setup failed with error {}'.format(rc))
2524+
2525+ logging.debug('SSH tunnel created')
2526+
2527+ self._make_tmp_dir()
2528+
2529+ def close(self):
2530+ """
2531+ Close the SSH tunnel after cleaning the temp directory.
2532+ """
2533+ if self._is_localhost:
2534+ return
2535+
2536+ # Clean the temp directory first, while the control master is still up.
2537+ self._delete_tmp_dir()
2538+
2539+ # ssh -O sends a control command to the master specified in -S, -q for quiet.
2540+ cmd = ['ssh', '-S', self._controller_name,
2541+ '-O', 'exit', '-q', self.remote_host]
2542+
2543+ logging.debug('Shutting down SSH tunnel with cmd: `{}`'
2544+ .format(' '.join(cmd)))
2545+
2546+ # SSH has a bug on some systems which causes it to ignore the -q flag
2547+ # meaning it prints "Exit request sent." to stderr (bytes on Python 3).
2548+ # Suppress that message and empty output; report anything else.
2549+ # Bytes literals are also plain str on Python 2, so this works on both.
2550+ pr = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
2551+ stdout, stderr = pr.communicate()
2552+ if stderr.strip() not in (b'', b'Exit request sent.'):
2553+ logging.error('SSH shutdown stderr: {}'.format(stderr))
2554+
2555+ if pr.returncode == 0:
2556+ logging.debug('Shut down successfully')
2557+ else:
2558+ logging.error('SSH shutdown failed with code {}'
2559+ .format(pr.returncode))
2560+
2561+ def _call(self, *args):
2562+ """
2563+ Execute a command on the remote machine over SSH.
2564+
2565+ Return a tuple of rc, stdout, stderr from the process call.
2566+
2567+ Arguments:
2568+ *args:
2569+ Strings containing the individual words of the command to
2570+ execute. E.g. _call('ls', '-lh', '.').
2571+ """
2572+ if self._is_localhost:
2573+ return (1, None, 'Cannot call SSH command from localhost')
2574+
2575+ ssh_cmd = ['ssh', '-S', self._controller_name, self.remote_host]
2576+ ssh_cmd.extend(args)
2577+ p = subprocess.Popen(ssh_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
2578+ stdout, stderr = p.communicate()
2579+
2580+ if p.returncode:
2581+ logging.debug('Call to `{}` failed with code: {}, stderr: {}'
2582+ .format(' '.join(ssh_cmd), p.returncode, stderr))
2583+
2584+ return p.returncode, stdout, stderr
2585+
2586+ def _make_tmp_dir(self):
2587+ """
2588+ Create the per-user temporary directory on the remote machine.
2589+ """
2590+ self._call('mkdir', '-p', self._tmp_dir)
2591+
2592+ def _delete_tmp_dir(self):
2593+ """
2594+ Remove the temporary directory on the remote machine.
2595+ """
2596+ self._call('rm', '-r', self._tmp_dir)
2597+
2598+ def send_to_tmp_dir(self, path_list):
2599+ """
2600+ Send the specified files to the temporary directory on the remote machine.
2601+
2602+ Return an iterator of save paths on the remote machine.
2603+
2604+ Raises:
2605+ ValueError:
2606+ If no file paths were provided, or if one or more of the
2607+ provided paths is not an actual file.
2608+ SSHCommandError:
2609+ If the scp command failed for any reason.
2610+
2611+ Arguments:
2612+ path_list:
2613+ Iterator of paths to the files on the local machine. All files
2614+ will be checked before copying to ensure that they exist and
2615+ to prevent passing arbitrary arguments to the scp
2616+ command.
2617+ """
2618+ if not path_list:
2619+ raise ValueError('attempt to copy zero files')
2620+
2621+ # If we are in localhost, we are not controlling the TMP_DIR,
2622+ # so the files are already there.
2623+ if self._is_localhost:
2624+ return path_list
2625+
2626+ # Make sure we can take the len of path_list and iterate over it
2627+ # more than once.
2628+ path_list = list(path_list)
2629+
2630+ # Check that actual files are being copied - not random strings.
2631+ to_copy = [f for f in path_list if os.path.isfile(f)]
2632+
2633+ if len(to_copy) < len(path_list):
2634+ missed = [f for f in path_list if f not in to_copy]
2635+ raise ValueError('Attempted to copy non existant files: {}'
2636+ .format(', '.join(missed)))
2637+
2638+ # scp does not use the -S control socket, so pass the login user itself.
2639+ host = '{}@{}'.format(self.ssh_user, self.remote_host) if self.ssh_user else self.remote_host
2640+ scp_cmd = ['scp'] + to_copy + ['{}:{}'.format(host, self._tmp_dir)]
2641+
2642+ proc = subprocess.Popen(scp_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
2643+ stdout, stderr = proc.communicate() # unlike wait(), cannot deadlock on full pipes
2644+ if proc.returncode:
2645+ raise SSHCommandError('scp failed with code {}: {}'.format(proc.returncode, stderr))
2646+
2647+ return (os.path.join(self._tmp_dir, os.path.basename(f)) for f in to_copy)
2648+
2649+ def remove_from_tmp_dir(self, path_list):
2650+ """
2651+ Delete the specified files from the temporary directory on the remote machine.
2652+
2653+ The output of send_to_tmp_dir may be passed as input to this function;
2654+ only the basename of each path is used.
2655+
2656+ Failures are not raised as exceptions: if the rm command fails, the
2657+ error is only logged (at debug level) by _call.
2658+
2659+ Arguments:
2660+ path_list:
2661+ Iterator of paths of the files on the remote machine, relative
2662+ to the temporary directory. E.g. remove_from_tmp_dir(['foo'])
2663+ will delete the file self._tmp_dir/foo
2664+ """
2665+ if self._is_localhost:
2666+ return
2667+
2668+ # Assume we can trust this file list.
2669+ paths = [os.path.join(self._tmp_dir, os.path.basename(f)) for f in path_list]
2670+ if paths: # rm with no operands would just fail
2671+ self._call('rm', *paths)
2672
2673=== added file 'src/sextant/test_all.sh'
2674--- src/sextant/test_all.sh 1970-01-01 00:00:00 +0000
2675+++ src/sextant/test_all.sh 2014-10-23 12:33:12 +0000
2676@@ -0,0 +1,4 @@
2677+#!/usr/bin/bash
2678+
2679+PYTHONPATH=$PYTHONPATH:~/.
2680+python -m unittest discover --pattern=test_*.py
2681
2682=== added file 'src/sextant/test_csvwriter.py'
2683--- src/sextant/test_csvwriter.py 1970-01-01 00:00:00 +0000
2684+++ src/sextant/test_csvwriter.py 2014-10-23 12:33:12 +0000
2685@@ -0,0 +1,89 @@
2686+#!/usr/bin/python
2687+import unittest
2688+from csvwriter import CSVWriter
2689+import subprocess
2690+from os import listdir
2691+
2692+class TestSequence(unittest.TestCase):
2693+ def get_writer(self, path='tmp_test', headers=['name', 'type'], split=100):
2694+ return CSVWriter(path, headers, split)
2695+
2696+ def tearDown(self):
2697+ to_rm = [f for f in listdir('.') if f.startswith('tmp_test') and f.endswith('.csv')]
2698+ if to_rm:
2699+ rc = subprocess.call(['rm'] + to_rm)
2700+ if rc:
2701+ msg = 'failed to clean'
2702+ else:
2703+ msg = 'cleaned'
2704+ print('{} {} files {}'.format(msg, len(to_rm), to_rm))
2705+
2706+ def test_headers(self):
2707+ # check that headers are being written correctly
2708+ headers = ['some', 'headers', 'to', 'check']
2709+ writer = self.get_writer(headers=headers)
2710+ writer.finish()
2711+
2712+ expected_path = 'tmp_test0.csv'
2713+ self.assertEquals(writer.file_iter().next(), expected_path)
2714+ writer_file = open('tmp_test0.csv', 'r')
2715+
2716+ self.assertEquals(writer_file.readline(), 'some,headers,to,check\n')
2717+ self.assertFalse(writer_file.readline()) # check that nothing extra is written
2718+
2719+ writer_file.close()
2720+
2721+ def test_writing(self):
2722+ # check that csv entries are written correctly, and errors
2723+ # appropriately raised for invalid input
2724+ writer = self.get_writer()
2725+
2726+ self.assertRaises(ValueError, writer.write, 'too short')
2727+ self.assertRaises(ValueError, writer.write, 'slightly', 'too', 'long')
2728+ writer.write('just', 'write')
2729+
2730+ writer.finish()
2731+
2732+ writer_file = open(writer.file_iter().next(), 'r+')
2733+
2734+ self.assertEqual(writer_file.readline(), 'name,type\n')
2735+ self.assertEqual(writer_file.readline(), 'just,write\n')
2736+ self.assertFalse(writer_file.readline())
2737+
2738+ writer_file.close()
2739+
2740+ def test_split(self):
2741+ split = 10
2742+ files = 10
2743+ writer = self.get_writer(split=split)
2744+
2745+ for i in xrange(files*(split-1)): # split-1 to account for header line
2746+ writer.write('an', 'entry')
2747+
2748+ writer.finish()
2749+
2750+ gen_count = sum(1 for f in writer.file_iter())
2751+ self.assertEqual(gen_count, files,
2752+ 'generated {} files, expected {}'
2753+ .format(gen_count, files))
2754+
2755+ for f in writer.file_iter():
2756+ with open(f, 'r+') as wf:
2757+ header_line = wf.readline()
2758+ header_expected = 'name,type\n'
2759+ self.assertEqual(header_line, header_expected,
2760+ '{} contained header {}, expected {}'
2761+ .format(f, header_line, header_expected)) # check headers
2762+
2763+ # check line count
2764+ with open(f, 'r+') as wf:
2765+ line_count = sum(1 for line in wf)
2766+ self.assertEqual(line_count, split,
2767+ '{} contained {} lines, expected {}'
2768+ .format(f, line_count, split))
2769+
2770+
2771+if __name__ == '__main__':
2772+ unittest.main()
2773+
2774+
2775
2776=== renamed file 'src/sextant/tests.py' => 'src/sextant/test_db_api.py' (properties changed: -x to +x)
2777--- src/sextant/tests.py 2014-08-14 15:23:39 +0000
2778+++ src/sextant/test_db_api.py 2014-10-23 12:33:12 +0000
2779@@ -1,3 +1,4 @@
2780+#!/usr/bin/python
2781 # -----------------------------------------
2782 # Sextant
2783 # Copyright 2014, Ensoft Ltd.
2784@@ -10,56 +11,69 @@
2785 from db_api import Function
2786 from db_api import FunctionQueryResult
2787 from db_api import SextantConnection
2788-from db_api import Validator
2789+from db_api import validate_query
2790
2791
2792 class TestFunctionQueryResults(unittest.TestCase):
2793- def setUp(self):
2794+ @classmethod
2795+ def setUpClass(cls):
2796 # we need to set up the remote database by using the neo4j_input_api
2797- self.remote_url = 'http://ensoft-sandbox:7474'
2798-
2799- self.setter_connection = SextantConnection(self.remote_url)
2800- self.program_1_name = 'testprogram'
2801- self.upload_program = self.setter_connection.new_program(self.program_1_name)
2802- self.upload_program.add_function('func1')
2803- self.upload_program.add_function('func2')
2804- self.upload_program.add_function('func3')
2805- self.upload_program.add_function('func4')
2806- self.upload_program.add_function('func5')
2807- self.upload_program.add_function('func6')
2808- self.upload_program.add_function('func7')
2809- self.upload_program.add_function_call('func1', 'func2')
2810- self.upload_program.add_function_call('func1', 'func4')
2811- self.upload_program.add_function_call('func2', 'func1')
2812- self.upload_program.add_function_call('func2', 'func4')
2813- self.upload_program.add_function_call('func3', 'func5')
2814- self.upload_program.add_function_call('func4', 'func4')
2815- self.upload_program.add_function_call('func4', 'func5')
2816- self.upload_program.add_function_call('func5', 'func1')
2817- self.upload_program.add_function_call('func5', 'func2')
2818- self.upload_program.add_function_call('func5', 'func3')
2819- self.upload_program.add_function_call('func6', 'func7')
2820-
2821- self.upload_program.commit()
2822-
2823- self.one_node_program_name = 'testprogram1'
2824- self.upload_one_node_program = self.setter_connection.new_program(self.one_node_program_name)
2825- self.upload_one_node_program.add_function('lonefunc')
2826-
2827- self.upload_one_node_program.commit()
2828+ cls.remote_url = 'http://ensoft-sandbox:7474'
2829+
2830+ cls.setter_connection = SextantConnection('ensoft-sandbox', 7474)
2831+
2832+ cls.program_1_name = 'testprogram'
2833+ cls.one_node_program_name = 'testprogram1'
2834+ cls.empty_program_name = 'testprogramblank'
2835+
2836+ # if anything failed before, delete programs now
2837+ cls.setter_connection.delete_program(cls.program_1_name)
2838+ cls.setter_connection.delete_program(cls.one_node_program_name)
2839+ cls.setter_connection.delete_program(cls.empty_program_name)
2840+
2841+
2842+ cls.upload_program = cls.setter_connection.new_program(cls.program_1_name)
2843+ cls.upload_program.add_function('func1')
2844+ cls.upload_program.add_function('func2')
2845+ cls.upload_program.add_function('func3')
2846+ cls.upload_program.add_function('func4')
2847+ cls.upload_program.add_function('func5')
2848+ cls.upload_program.add_function('func6')
2849+ cls.upload_program.add_function('func7')
2850+ cls.upload_program.add_call('func1', 'func2')
2851+ cls.upload_program.add_call('func1', 'func4')
2852+ cls.upload_program.add_call('func2', 'func1')
2853+ cls.upload_program.add_call('func2', 'func4')
2854+ cls.upload_program.add_call('func3', 'func5')
2855+ cls.upload_program.add_call('func4', 'func4')
2856+ cls.upload_program.add_call('func4', 'func5')
2857+ cls.upload_program.add_call('func5', 'func1')
2858+ cls.upload_program.add_call('func5', 'func2')
2859+ cls.upload_program.add_call('func5', 'func3')
2860+ cls.upload_program.add_call('func6', 'func7')
2861+
2862+ cls.upload_program.commit()
2863+
2864+ cls.upload_one_node_program = cls.setter_connection.new_program(cls.one_node_program_name)
2865+ cls.upload_one_node_program.add_function('lonefunc')
2866+
2867+ cls.upload_one_node_program.commit()
2868
2869- self.empty_program_name = 'testprogramblank'
2870- self.upload_empty_program = self.setter_connection.new_program(self.empty_program_name)
2871-
2872- self.upload_empty_program.commit()
2873-
2874- self.getter_connection = SextantConnection(self.remote_url)
2875-
2876- def tearDown(self):
2877- self.setter_connection.delete_program(self.upload_program.program_name)
2878- self.setter_connection.delete_program(self.upload_one_node_program.program_name)
2879- self.setter_connection.delete_program(self.upload_empty_program.program_name)
2880- del(self.setter_connection)
2881+ cls.upload_empty_program = cls.setter_connection.new_program(cls.empty_program_name)
2882+
2883+ cls.upload_empty_program.commit()
2884+
2885+ cls.getter_connection = cls.setter_connection
2886+
2887+
2888+ @classmethod
2889+ def tearDownClass(cls):
2890+ cls.setter_connection.delete_program(cls.upload_program.program_name)
2891+ cls.setter_connection.delete_program(cls.upload_one_node_program.program_name)
2892+ cls.setter_connection.delete_program(cls.upload_empty_program.program_name)
2893+
2894+ cls.setter_connection.close()
2895+ del(cls.setter_connection)
2896
2897 def test_17_get_call_paths(self):
2898 reference1 = FunctionQueryResult(parent_db=None, program_name=self.program_1_name)
2899@@ -134,7 +148,7 @@
2900
2901 def test_08_get_program_names(self):
2902 reference = {self.program_1_name, self.one_node_program_name, self.empty_program_name}
2903- self.assertEqual(reference, self.getter_connection.get_program_names())
2904+ self.assertTrue(reference.issubset(self.getter_connection.get_program_names()))
2905
2906
2907 def test_11_get_all_functions_called(self):
2908@@ -249,13 +263,13 @@
2909 self.assertIsNone(self.getter_connection.get_call_paths(self.one_node_program_name, 'notafunc', 'notafunc'))
2910
2911 def test_10_validator(self):
2912- self.assertFalse(Validator.validate(''))
2913- self.assertTrue(Validator.validate('thisworks'))
2914- self.assertTrue(Validator.validate('th1sw0rks'))
2915- self.assertTrue(Validator.validate('12345'))
2916- self.assertFalse(Validator.validate('this does not work'))
2917- self.assertTrue(Validator.validate('this_does_work'))
2918- self.assertFalse(Validator.validate("'")) # string consisting of a single quote mark
2919+ self.assertFalse(validate_query(''))
2920+ self.assertTrue(validate_query('thisworks'))
2921+ self.assertTrue(validate_query('th1sw0rks'))
2922+ self.assertTrue(validate_query('12345'))
2923+ self.assertFalse(validate_query('this does not work'))
2924+ self.assertTrue(validate_query('this_does_work'))
2925+ self.assertFalse(validate_query("'")) # string consisting of a single quote mark
2926
2927 if __name__ == '__main__':
2928- unittest.main()
2929\ No newline at end of file
2930+ unittest.main()
2931
2932=== added file 'src/sextant/test_parser.py'
2933--- src/sextant/test_parser.py 1970-01-01 00:00:00 +0000
2934+++ src/sextant/test_parser.py 2014-10-23 12:33:12 +0000
2935@@ -0,0 +1,85 @@
2936+#!/usr/bin/python
2937+from collections import defaultdict
2938+import unittest
2939+import subprocess
2940+
2941+import objdump_parser as parser
2942+
2943+DUMP_FILE = 'test_resources/parser_test.dump'
2944+
2945+class TestSequence(unittest.TestCase):
2946+ def setUp(self):
2947+ pass
2948+
2949+ def add_function(self, dct, name, typ):
2950+ self.assertFalse(name in dct, "duplicate function added: {} into {}".format(name, dct.keys()))
2951+ dct[name] = typ
2952+
2953+ def add_call(self, dct, caller, callee):
2954+ dct[caller].append(callee)
2955+
2956+ def do_parse(self, path=DUMP_FILE, sections=['.text'], ignore_ptrs=False):
2957+ functions = {}
2958+ calls = defaultdict(list)
2959+
2960+ # set the Parser to put output in local dictionaries
2961+ add_function = lambda n, t: self.add_function(functions, n, t)
2962+ add_call = lambda a, b: self.add_call(calls, a, b)
2963+
2964+ p = parser.Parser(path, sections=sections, ignore_ptrs=ignore_ptrs,
2965+ add_function=add_function, add_call=add_call)
2966+ res = p.parse()
2967+
2968+ parser.add_function = None
2969+ parser.add_call = None
2970+
2971+ return res, functions, calls
2972+
2973+
2974+ def test_open(self):
2975+ self.assertRaises(parser.FileNotFoundError, parser.Parser, file_path='rubbish file')
2976+
2977+ def test_functions(self):
2978+ # ensure that the correct functions are listed with the correct types
2979+ res, funcs, calls = self.do_parse()
2980+
2981+ for name, typ in zip(['normal', 'duplicates', 'wierd$name', 'printf', 'func_ptr_3'],
2982+ ['normal', 'normal', 'normal', 'stub', 'pointer']):
2983+ self.assertTrue(name in funcs, "'{}' not found in function dictionary".format(name))
2984+ self.assertEquals(funcs[name], typ)
2985+
2986+ self.assertFalse('__gmon_start__' in funcs, "don't see a function defined in .plt")
2987+
2988+ def test_no_ptrs(self):
2989+ # ensure that the ignore_ptrs flags is working
2990+ res, funcs, calls = self.do_parse(ignore_ptrs=True)
2991+
2992+ self.assertFalse('pointer' in funcs.values())
2993+ self.assertEqual(len(calls['normal']), 2)
2994+
2995+
2996+ def test_calls(self):
2997+ res, funcs, calls = self.do_parse()
2998+
2999+ self.assertTrue('normal' in calls['main'])
3000+ self.assertTrue('duplicates' in calls['main'])
3001+
3002+ normal_calls = sorted(['wierd$name', 'printf', 'func_ptr_3'])
3003+ self.assertEquals(sorted(calls['normal']), normal_calls)
3004+
3005+ self.assertEquals(calls['duplicates'].count('normal'), 2)
3006+ self.assertEquals(calls['duplicates'].count('printf'), 2,
3007+ "expected 2 printf calls in {}".format(calls['duplicates']))
3008+ self.assertTrue('func_ptr_4' in calls['duplicates'])
3009+ self.assertTrue('func_ptr_5' in calls['duplicates'])
3010+
3011+ def test_sections(self):
3012+ res, funcs, calls = self.do_parse(sections=['.plt', '.text'])
3013+
3014+ # check that we have got rid of the @s in the names
3015+ self.assertTrue('@' not in ''.join(funcs.keys()), "check names are extracted correctly")
3016+ self.assertTrue('__gmon_start__' in funcs, "see a function defined only in .plt")
3017+
3018+
3019+if __name__ == '__main__':
3020+ unittest.main()
3021
3022=== added directory 'src/sextant/test_resources'
3023=== added file 'src/sextant/test_resources/parser_test'
3024Binary files src/sextant/test_resources/parser_test 1970-01-01 00:00:00 +0000 and src/sextant/test_resources/parser_test 2014-10-23 12:33:12 +0000 differ
3025=== added file 'src/sextant/test_resources/parser_test.c'
3026--- src/sextant/test_resources/parser_test.c 1970-01-01 00:00:00 +0000
3027+++ src/sextant/test_resources/parser_test.c 2014-10-23 12:33:12 +0000
3028@@ -0,0 +1,57 @@
3029+// COMMENT
3030+#include<stdio.h>
3031+
3032+static int
3033+normal(int a);
3034+
3035+static int
3036+wierd$name(int a);
3037+
3038+typedef int (*pointer)(int);
3039+
3040+static int
3041+normal(int a)
3042+{
3043+ /* call a normal func,
3044+ * a stub and a pointer
3045+ */
3046+ pointer ptr = wierd$name;
3047+
3048+ wierd$name(a);
3049+ printf("%d\n", a);
3050+ ptr(a);
3051+
3052+ return (a);
3053+}
3054+
3055+static int
3056+wierd$name(int a)
3057+{
3058+ return (a);
3059+}
3060+
3061+static int
3062+duplicates(int a)
3063+{
3064+ pointer ptr1 = wierd$name;
3065+
3066+ /* check stubs don't get duplicated */
3067+ printf("first %d\n", a);
3068+ printf("second %d\n", a);
3069+
3070+ normal(a);
3071+ normal(a);
3072+
3073+ ptr1(a);
3074+ ptr1(a);
3075+
3076+ return (a);
3077+}
3078+
3079+int
3080+main(void)
3081+{
3082+ normal(1);
3083+ duplicates(1);
3084+ return (0);
3085+}
3086
3087=== added file 'src/sextant/test_resources/parser_test.dump'
3088--- src/sextant/test_resources/parser_test.dump 1970-01-01 00:00:00 +0000
3089+++ src/sextant/test_resources/parser_test.dump 2014-10-23 12:33:12 +0000
3090@@ -0,0 +1,44 @@
3091+Disassembly of section .init:
3092+080482b4 <_init>:
3093+ 80482b8: call 8048350 <__x86.get_pc_thunk.bx>
3094+ 80482cd: call 8048300 <__gmon_start__@plt>
3095+Disassembly of section .plt:
3096+080482e0 <printf@plt-0x10>:
3097+080482f0 <printf@plt>:
3098+08048300 <__gmon_start__@plt>:
3099+08048310 <__libc_start_main@plt>:
3100+Disassembly of section .text:
3101+08048320 <_start>:
3102+ 804833c: call 8048310 <__libc_start_main@plt>
3103+08048350 <__x86.get_pc_thunk.bx>:
3104+08048360 <deregister_tm_clones>:
3105+ 8048386: call *%eax
3106+08048390 <register_tm_clones>:
3107+ 80483c3: call *%edx
3108+080483d0 <__do_global_dtors_aux>:
3109+ 80483df: call 8048360 <deregister_tm_clones>
3110+080483f0 <frame_dummy>:
3111+ 804840f: call *%eax
3112+0804841d <normal>:
3113+ 8048430: call 8048458 <wierd$name>
3114+ 8048443: call 80482f0 <printf@plt>
3115+ 8048451: call *%eax
3116+08048458 <wierd$name>:
3117+08048460 <duplicates>:
3118+ 804847b: call 80482f0 <printf@plt>
3119+ 804848e: call 80482f0 <printf@plt>
3120+ 8048499: call 804841d <normal>
3121+ 80484a4: call 804841d <normal>
3122+ 80484b2: call *%eax
3123+ 80484bd: call *%eax
3124+080484c4 <main>:
3125+ 80484d4: call 804841d <normal>
3126+ 80484e0: call 8048460 <duplicates>
3127+080484f0 <__libc_csu_init>:
3128+ 80484f6: call 8048350 <__x86.get_pc_thunk.bx>
3129+ 804850e: call 80482b4 <_init>
3130+ 804853b: call *-0xf8(%ebx,%edi,4)
3131+08048560 <__libc_csu_fini>:
3132+Disassembly of section .fini:
3133+08048564 <_fini>:
3134+ 8048568: call 8048350 <__x86.get_pc_thunk.bx>
3135
3136=== added file 'src/sextant/test_sshmanager.py'
3137--- src/sextant/test_sshmanager.py 1970-01-01 00:00:00 +0000
3138+++ src/sextant/test_sshmanager.py 2014-10-23 12:33:12 +0000
3139@@ -0,0 +1,72 @@
3140+#!/usr/bin/python3
3141+import unittest
3142+import sshmanager
3143+import sshmanager
3144+import os
3145+sshmanager.TMP_DIR = '/home/benhutc/obj/csvload/src/sextant/test_resources/tmp'
3146+
3147+
3148+class TestSequence(unittest.TestCase):
3149+ def setUp(self):
3150+ self.manager = None
3151+
3152+ def tearDown(self):
3153+ if self.manager:
3154+ self.manager.close()
3155+ self.manager = None
3156+
3157+ def get_manager(self, local_port=9643, remote_host='localhost',
3158+ remote_port=9643, ssh_user=None):
3159+ return sshmanager.SSHManager(local_port, remote_host, remote_port, ssh_user)
3160+
3161+ def test_init(self):
3162+ self.assertRaises(ValueError, self.get_manager, local_port='invalid port')
3163+ self.assertRaises(ValueError, self.get_manager, remote_port='invalid port')
3164+
3165+ def test_connect(self):
3166+ # make a connection to localhost and ensure that tmp is created
3167+ self.manager = self.get_manager()
3168+ self.assertTrue(os.path.isdir(self.manager._tmp_dir))
3169+ self.manager.close()
3170+ self.assertFalse(os.path.isdir(self.manager._tmp_dir))
3171+ self.manager = None
3172+
3173+ # check connecion failure
3174+ self.assertRaises(sshmanager.SSHConnectionError, self.get_manager, remote_host='invalid host')
3175+
3176+ def test_files(self):
3177+ genuine_file = 'test_resources/parser_test.c'
3178+ genuine_file2 = 'test_resources/parser_test'
3179+ absent_file = 'absent_file'
3180+
3181+ self.manager = self.get_manager()
3182+ # check sending no files fails
3183+ self.assertRaises(ValueError, self.manager.send_to_tmp_dir, [])
3184+ # and sending an non-existent file
3185+ self.assertRaises(ValueError, self.manager.send_to_tmp_dir, [absent_file, genuine_file])
3186+
3187+ self.manager.send_to_tmp_dir([genuine_file, genuine_file2])
3188+ self.assertTrue(os.path.isfile(os.path.join(self.manager._tmp_dir, genuine_file.split('/')[-1])))
3189+ self.assertTrue(os.path.isfile(os.path.join(self.manager._tmp_dir, genuine_file2.split('/')[-1])))
3190+
3191+ self.manager.remove_from_tmp_dir([genuine_file, genuine_file2])
3192+ self.assertFalse(os.path.isfile(os.path.join(self.manager._tmp_dir,
3193+ genuine_file.split('/')[-1])))
3194+ self.assertFalse(os.path.isfile(os.path.join(self.manager._tmp_dir,
3195+ genuine_file2.split('/')[-1])))
3196+
3197+
3198+ self.manager.close()
3199+ self.manager = None
3200+
3201+
3202+if __name__ == '__main__':
3203+ # no coverage for:
3204+ # specifying ssh user
3205+ # scp failure
3206+ # an error in closing the ssh connection
3207+ # another error in closing the ssh connection
3208+ # mkdir failure
3209+ # rmdir failure
3210+ unittest.main()
3211+
3212
3213=== modified file 'src/sextant/update_db.py'
3214--- src/sextant/update_db.py 2014-09-29 14:01:39 +0000
3215+++ src/sextant/update_db.py 2014-10-23 12:33:12 +0000
3216@@ -5,72 +5,106 @@
3217 # -----------------------------------------
3218 # Given a program file to upload, or a program name to delete from the server, does the right thing.
3219
3220+from __future__ import print_function
3221+
3222 __all__ = ("upload_program", "delete_program")
3223
3224-from .db_api import SextantConnection, Validator
3225-from .objdump_parser import get_parsed_objects
3226+from .db_api import SextantConnection
3227+from .sshmanager import SSHConnectionError
3228+from .objdump_parser import Parser, run_objdump
3229 from os import path
3230+from time import time
3231+import subprocess
3232+import sys
3233
3234 import logging
3235
3236-
3237-def upload_program(user_name, file_path, db_url, display_url='',
3238- alternative_name=None, not_object_file=False):
3239- """
3240- Uploads a program to the remote database.
3241-
3242- Raises requests.exceptions.ConnectionError if the server didn't exist.
3243- Raises IOError if file_path doesn't correspond to a file.
3244- Raises ValueError if the desired alternative_name (or the default, if no
3245- alternative_name was specified) already exists in the database.
3246- :param file_path: the path to the local file we wish to upload
3247- :param db_url: the URL of the database (eg. http://localhost:7474)
3248- :param display_url: alternative URL to display instead of db_url
3249- :param alternative_name: a name to give the program to override the default
3250- :param object_file: bool(the file is an objdump text output file, rather than a compiled binary)
3251-
3252- """
3253-
3254- if not display_url:
3255- display_url = db_url
3256-
3257- # if no name is specified, use the form "<username>-<binary name>"
3258- name = alternative_name or (user_name + '-' + path.split(file_path)[-1])
3259-
3260- connection = SextantConnection(db_url)
3261-
3262- program_names = connection.get_program_names()
3263- if Validator.sanitise(name) in program_names:
3264- raise ValueError("There is already a program with name {}; "
3265- "please delete the previous one with the same name "
3266- "and retry, or rename the input file.".format(name))
3267-
3268- parsed_objects = get_parsed_objects(filepath=file_path,
3269- sections_to_view=['.text'],
3270- not_object_file=not_object_file,
3271- ignore_function_pointers=False)
3272-
3273- logging.info('Objdump has parsed!')
3274-
3275- program_representation = connection.new_program(Validator.sanitise(name))
3276-
3277- for obj in parsed_objects:
3278- for called in obj.what_do_i_call:
3279- if not program_representation.add_function_call(obj.name, called[-1]): # called is a tuple (address, name)
3280- logging.error('Validation error: {} calling {}'.format(obj.name, called[-1]))
3281-
3282- logging.info('Sending {} named objects to server {}...'.format(len(parsed_objects), display_url))
3283- program_representation.commit()
3284- logging.info('Successfully added {}.'.format(name))
3285-
3286-
3287-def delete_program(program_name, db_url):
3288- """
3289- Deletes a program with the specified name from the database.
3290- :param program_name: the name of the program to delete
3291- :param db_url: the URL of the database (eg. http://localhost:7474)
3292- :return: bool(success)
3293- """
3294- connection = SextantConnection(db_url)
3295+def upload_program(connection, user_name, file_path, program_name=None,
3296+ not_object_file=False):
3297+ """
3298+ Upload a program's functions and call graph to the database.
3299+
3300+ Arguments:
3301+ connection:
3302+ The SextantConnection object that manages the database connection.
3303+ user_name:
3304+ The user name of the user uploading the program.
3305+ file_path:
3306+ The path to either: the output of objdump (if not_object_file is
3307+ True) OR to a binary file if (not_object_file is False).
3308+ program_name:
3309+ An optional name to give the program in the database, if not
3310+ specified then <user_name>-<file name> will be used.
3311+ not_object_file:
3312+ Flag controlling whether file_path is pointing to a dump file or
3313+ a binary file.
3314+ """
3315+ if not connection._ssh:
3316+ raise SSHConnectionError('An SSH connection is required for '
3317+ 'program upload.')
3318+
3319+ if not program_name:
3320+ file_no_ext = path.basename(file_path).split('.')[0]
3321+ program_name = '{}-{}'.format(user_name, file_no_ext)
3322+
3323+
3324+ if program_name in connection.get_program_names():
3325+ raise ValueError('A program with name `{}` already exists in the database'
3326+ .format(program_name))
3327+
3328+
3329+ print('Uploading `{}` to the database. '
3330+ 'This may take some time for larger programs.'
3331+ .format(program_name))
3332+ start = time()
3333+
3334+ if not not_object_file:
3335+ print('Generating dump file...', end='')
3336+ sys.stdout.flush()
3337+ file_path, file_object = run_objdump(file_path)
3338+ print('done.')
3339+ else:
3340+ file_object = None
3341+
3342+ # Make parser and wire to DBprogram.
3343+ with connection.new_program(program_name) as program:
3344+
3345+ def start_parser(program):
3346+ print('Parsing dump file...', end='')
3347+ sys.stdout.flush()
3348+
3349+ def finish_parser(parser, program):
3350+ # Callback to make sure the program's csv files are flushed when
3351+ # the parser completes.
3352+ program.func_writer.finish()
3353+ program.call_writer.finish()
3354+
3355+ print('done: {} functions and {} calls.'
3356+ .format(parser.function_count, parser.call_count))
3357+
3358+ parser = Parser(file_path = file_path, file_object = file_object,
3359+ sections=[],
3360+ add_function = program.add_function,
3361+ add_call = program.add_call,
3362+ started=lambda parser: start_parser(program),
3363+ finished=lambda parser: finish_parser(parser, program))
3364+ parser.parse()
3365+
3366+ program.commit()
3367+
3368+ end = time()
3369+ print('Finished in {:.2f}s.'.format(end-start))
3370+
3371+
3372+def delete_program(connection, program_name):
3373+ """
3374+ Remove the specified program from the database.
3375+
3376+ Arguments:
3377+ connection:
3378+ The SextantConnection object managing the database connection.
3379+ program_name:
3380+ The name of the program to remove from the database.
3381+ """
3382 connection.delete_program(program_name)
3383- print('Successfully deleted {}.'.format(program_name))
3384+
3385
3386=== modified file 'src/sextant/web/server.py'
3387--- src/sextant/web/server.py 2014-10-03 11:47:52 +0000
3388+++ src/sextant/web/server.py 2014-10-23 12:33:12 +0000
3389@@ -26,7 +26,8 @@
3390
3391 from cgi import escape # deprecated in Python 3 in favour of html.escape, but we're stuck on Python 2
3392
3393-database_url = None # the URL to access the database instance
3394+# global SextantConnection object which deals with the port forwarding
3395+CONNECTION = None
3396
3397 RESPONSE_CODE_OK = 200
3398 RESPONSE_CODE_BAD_REQUEST = 400
3399@@ -67,25 +68,6 @@
3400
3401 class SVGRenderer(Resource):
3402
3403- def error_creating_neo4j_connection(self, failure):
3404- self.write("Error creating Neo4J connection: %s\n") % failure.getErrorMessage()
3405-
3406- @staticmethod
3407- def create_neo4j_connection():
3408- return db_api.SextantConnection(database_url)
3409-
3410- @staticmethod
3411- def check_program_exists(connection, name):
3412- return connection.check_program_exists(name)
3413-
3414- @staticmethod
3415- def get_whole_program(connection, name):
3416- return connection.get_whole_program(name)
3417-
3418- @staticmethod
3419- def get_functions_calling(connection, progname, funcname):
3420- return connection.get_all_functions_calling(progname, funcname)
3421-
3422 @staticmethod
3423 def get_plot(program, suppress_common_functions=False, remove_self_calls=False):
3424 graph_dot = export.ProgramConverter.to_dot(program, suppress_common_functions,
3425@@ -111,7 +93,7 @@
3426 res_msg = None # set this in the logic
3427
3428 #
3429- # Get program name and database connection, check if program exists
3430+ # Check if provided program name exists
3431 #
3432
3433 name = args.get('program_name', [None])[0]
3434@@ -121,16 +103,7 @@
3435 res_msg = "Supply 'program_name' parameter."
3436
3437 if res_code is RESPONSE_CODE_OK:
3438- try:
3439- conn = yield deferToThread(self.create_neo4j_connection)
3440- except requests.exceptions.ConnectionError:
3441- res_code = RESPONSE_CODE_BAD_GATEWAY
3442- res_fmt = "Could not reach Neo4j server at {}"
3443- res_msg = res_fmt.format(database_url)
3444- conn = None
3445-
3446- if res_code is RESPONSE_CODE_OK:
3447- exists = yield deferToThread(self.check_program_exists, conn, name)
3448+ exists = yield deferToThread(CONNECTION.check_program_exists, name)
3449 if not exists:
3450 res_code = RESPONSE_CODE_NOT_FOUND
3451 res_fmt = "Program {} not found in database."
3452@@ -146,28 +119,23 @@
3453 # look for in request.args, both tuples
3454 queries = {
3455 'whole_program': (
3456- self.get_whole_program,
3457- (conn, name),
3458+ CONNECTION.get_whole_program,
3459 ()
3460 ),
3461 'functions_calling': (
3462- self.get_functions_calling,
3463- (conn, name),
3464+ CONNECTION.get_all_functions_calling,
3465 ('func1',)
3466 ),
3467 'functions_called_by': (
3468- conn.get_all_functions_called,
3469- (name,),
3470+ CONNECTION.get_all_functions_called,
3471 ('func1',)
3472 ),
3473 'all_call_paths': (
3474- conn.get_call_paths,
3475- (name,),
3476+ CONNECTION.get_call_paths,
3477 ('func1', 'func2')
3478 ),
3479 'shortest_call_path': (
3480- conn.get_shortest_path_between_functions,
3481- (name,),
3482+ CONNECTION.get_shortest_path_between_functions,
3483 ('func1', 'func2')
3484 )
3485 }
3486@@ -186,7 +154,7 @@
3487
3488 # extract any required keyword arguments from request.args
3489 if res_code is RESPONSE_CODE_OK:
3490- fn, known_args, kwargs = query
3491+ fn, kwargs = query
3492
3493 # all args will be strings - use None to indicate missing argument
3494 req_args = tuple(args.get(kwarg, [None])[0] for kwarg in kwargs)
3495@@ -202,9 +170,8 @@
3496 # if we are okay here we have a valid query with all required arguments
3497 if res_code is RESPONSE_CODE_OK:
3498 try:
3499- all_args = known_args + req_args
3500 program = yield defer_to_thread_with_timeout(render_timeout, fn,
3501- *all_args)
3502+ name, *req_args)
3503 except defer.CancelledError:
3504 # the timeout has fired and cancelled the request
3505 res_code = RESPONSE_CODE_BAD_REQUEST
3506@@ -247,16 +214,12 @@
3507 class GraphProperties(Resource):
3508
3509 @staticmethod
3510- def _get_connection():
3511- return db_api.SextantConnection(database_url)
3512-
3513- @staticmethod
3514- def _get_program_names(connection):
3515- return connection.get_program_names()
3516-
3517- @staticmethod
3518- def _get_function_names(connection, program_name):
3519- return connection.get_function_names(program_name)
3520+ def _get_program_names():
3521+ return CONNECTION.get_program_names()
3522+
3523+ @staticmethod
3524+ def _get_function_names(program_name):
3525+ return CONNECTION.get_function_names(program_name)
3526
3527 @defer.inlineCallbacks
3528 def _render_GET(self, request):
3529@@ -269,18 +232,9 @@
3530
3531 query = request.args['query'][0]
3532
3533- try:
3534- neo4j_connection = yield deferToThread(self._get_connection)
3535- except Exception:
3536- request.setResponseCode(502) # Bad Gateway
3537- request.write("Could not reach Neo4j server at {}.".format(database_url))
3538- request.finish()
3539- defer.returnValue(None)
3540- neo4j_connection = None # just to silence the "referenced before assignment" warnings
3541-
3542 if query == 'programs':
3543 request.setHeader("content-type", "application/json")
3544- prognames = yield deferToThread(self._get_program_names, neo4j_connection)
3545+ prognames = yield deferToThread(self._get_program_names)
3546 request.write(json.dumps(list(prognames)))
3547 request.finish()
3548 defer.returnValue(None)
3549@@ -294,7 +248,7 @@
3550 defer.returnValue(None)
3551 program_name = request.args['program_name'][0]
3552
3553- funcnames = yield deferToThread(self._get_function_names, neo4j_connection, program_name)
3554+ funcnames = yield deferToThread(self._get_function_names, program_name)
3555 if funcnames is None:
3556 request.setResponseCode(404)
3557 request.setHeader("content-type", "text/plain")
3558@@ -319,10 +273,12 @@
3559 return NOT_DONE_YET
3560
3561
3562-def serve_site(input_database_url='http://localhost:7474', port=2905):
3563-
3564- global database_url
3565- database_url = input_database_url
3566+def serve_site(connection, port):
3567+ global CONNECTION
3568+
3569+ CONNECTION = connection
3570+
3571+
3572 # serve static directory at root
3573 root = File(os.path.join(environment.RESOURCES_DIR, 'sextant', 'web'))
3574

Subscribers

People subscribed via source and target branches