Merge lp:~cjwatson/canonical-identity-provider/api-accounts-bulk into lp:canonical-identity-provider/release

Proposed by Colin Watson
Status: Merged
Approved by: Colin Watson
Approved revision: no longer in the source branch.
Merge reported by: Otto Co-Pilot
Merged at revision: not available
Proposed branch: lp:~cjwatson/canonical-identity-provider/api-accounts-bulk
Merge into: lp:canonical-identity-provider/release
Diff against target: 975 lines (+571/-68)
14 files modified
src/api/v20/auth.py (+10/-1)
src/api/v20/authorization.py (+18/-13)
src/api/v20/handlers.py (+34/-3)
src/api/v20/tests/test_handlers.py (+94/-0)
src/api/v20/tests/test_utils.py (+62/-0)
src/api/v20/urls.py (+7/-1)
src/api/v20/utils.py (+75/-26)
src/identityprovider/models/account.py (+29/-13)
src/identityprovider/models/person.py (+45/-2)
src/identityprovider/signals.py (+16/-1)
src/identityprovider/tests/test_models_account.py (+10/-1)
src/identityprovider/tests/test_models_person.py (+40/-1)
src/identityprovider/tests/test_signals.py (+65/-6)
src/testing/helpers.py (+66/-0)
To merge this branch: bzr merge lp:~cjwatson/canonical-identity-provider/api-accounts-bulk
Reviewer: Natalia Bidart (community)
Review status: Approve
Review via email: mp+335393@code.launchpad.net

Commit message

Add an /api/v2/accounts-bulk endpoint.

This fetches information about a set of accounts in bulk by OpenID
identifier, useful for other services that need to render pages related to
multiple accounts.
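
As a rough illustration of how a client might talk to this endpoint, the sketch below builds the JSON request body and reads the reply. It assumes what the diff further down shows: the body carries an "openids" list plus an optional "expand" flag, and the response is a JSON object keyed by OpenID identifier whose entries include a "username" field. The helper names are hypothetical, and OAuth signing of the request is left out.

```python
import json


def build_accounts_bulk_body(openids, expand=False):
    """Serialize the JSON body for a POST to /api/v2/accounts-bulk.

    Hypothetical client-side helper, not part of the branch.
    """
    return json.dumps({"openids": list(openids), "expand": bool(expand)})


def usernames_by_openid(response_body):
    """Map each returned OpenID identifier to its "username" field."""
    accounts = json.loads(response_body)
    return {openid: data["username"] for openid, data in accounts.items()}
```

Identifiers that match no active account are simply absent from the response, so callers should not assume every requested key comes back.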

Description of the change

This is part of https://docs.google.com/document/d/1K7Dxm5RuKgioKVSyhc5uUefsM4UH91pRUfuYrUe5hCg.

Getting the bulk database queries right was really quite complicated, and I'd welcome feedback from Django experts on how I went about it. I think it is important, since the set of accounts being requested could quite reasonably be 50-100 or so, and we don't want the query count to scale linearly with all their email addresses. My general approach was to add methods that fetch objects by a set of accounts to each of the appropriate managers, and to convert some of the individual properties into bulk prefetching methods. There were quite significant complications around Account.preferredemail and Account.is_verified, though.
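
A minimal sketch of the in-memory selection that a prefetch-friendly Account.preferredemail implies, assuming the account's email rows have already been loaded. The Email tuple, the pick_preferred name and the status strings are illustrative stand-ins, not the project's models or EmailStatus constants. The sketch also leaves out the step where a VALIDATED address is promoted to PREFERRED and saved.

```python
from collections import namedtuple
from operator import attrgetter

# Hypothetical stand-in for an EmailAddress row.  date_created can be any
# sortable value; the status strings are placeholders.
Email = namedtuple("Email", ["email", "status", "date_created"])


def pick_preferred(emails):
    """Pick an address in memory: PREFERRED, then VALIDATED, then NEW.

    Within each status the oldest address wins.  Nothing is written back
    to the database in this sketch.
    """
    by_age = sorted(emails, key=attrgetter("date_created"))
    for status in ("PREFERRED", "VALIDATED", "NEW"):
        for email in by_age:
            if email.status == status:
                return email
    return None
```

Filtering a list that is already in memory costs no extra queries, whereas one filtered query per status would bypass whatever had been prefetched.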

Simon Davy (bloodearnest) wrote :

LGTM, left a few minor comments.

Also, I wonder if it's worth adding assertNumQueries checks to some of the bulk query tests?

https://docs.djangoproject.com/en/dev/topics/testing/tools/#django.test.TransactionTestCase.assertNumQueries
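
The idea behind such a check can be sketched without Django, using the standard library's sqlite3 module as a stand-in. This is not assertNumQueries itself and not the branch's own test helper. The table layout and function names are made up for illustration, and count_selects counts traced SELECT statements only.

```python
import sqlite3


def count_selects(conn, func):
    """Run func() and return how many SELECT statements it executed."""
    seen = []
    conn.set_trace_callback(seen.append)
    try:
        func()
    finally:
        conn.set_trace_callback(None)
    return sum(1 for sql in seen if sql.lstrip().upper().startswith("SELECT"))


def fetch_one_by_one(conn, account_ids):
    """N+1 style: one query per account."""
    return {
        account_id: conn.execute(
            "SELECT address FROM email WHERE account = ?", (account_id,)
        ).fetchall()
        for account_id in account_ids
    }


def fetch_in_bulk(conn, account_ids):
    """Bulk style: one query for the whole set, grouped in memory."""
    placeholders = ", ".join("?" for _ in account_ids)
    query = "SELECT account, address FROM email WHERE account IN (%s)" % (
        placeholders,)
    rows = conn.execute(query, list(account_ids)).fetchall()
    grouped = {account_id: [] for account_id in account_ids}
    for account_id, address in rows:
        grouped[account_id].append(address)
    return grouped


# Hypothetical fixture: 50 accounts with one address each.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE email (account INTEGER, address TEXT)")
conn.executemany(
    "INSERT INTO email VALUES (?, ?)",
    [(i, "user%d@example.com" % i) for i in range(50)],
)
conn.commit()
ids = list(range(50))
```

Comparing the count for a small and a large set of accounts is what shows that the bulk path stays constant while the one-by-one path grows with the input.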

Colin Watson (cjwatson) :
Natalia Bidart (nataliabidart) :
review: Needs Information
Colin Watson (cjwatson) :
Natalia Bidart (nataliabidart) wrote :

Thanks for the changes! I think the code is ready for landing, except for the following:

* a few nitpicks added as inline comments
* a few tests missing, for all the changes/additions to src/identityprovider/models/account.py, src/identityprovider/models/person.py and src/identityprovider/signals.py (the changes to signals can just be tested in test_models_account.py by deleting emails and asserting on the preferred email result).
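
The kind of test suggested for the signal change can be sketched in plain Python: cache a preferred email, delete the address, and check that the cached value no longer wins. The Account class and delete_email function below are hypothetical stand-ins rather than the project's models. Only the try/del/except AttributeError shape mirrors the handler in the diff.

```python
class Account(object):
    """Hypothetical stand-in with a cached preferredemail property."""

    def __init__(self, emails):
        self.emails = list(emails)

    @property
    def preferredemail(self):
        # Compute once, then reuse the cached value on later accesses.
        if not hasattr(self, "_preferredemail"):
            self._preferredemail = self.emails[0] if self.emails else None
        return self._preferredemail


def invalidate_preferredemail(account):
    """Drop the cached value, tolerating an account that never cached one."""
    try:
        del account._preferredemail
    except AttributeError:
        pass


def delete_email(account, email):
    """Mimic a pre_delete hook: invalidate the cache, then remove the row."""
    invalidate_preferredemail(account)
    account.emails.remove(email)
```

Reading preferredemail before the deletion matters in such a test, since otherwise it would pass even if the invalidation did nothing.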

Thanks!

Natalia Bidart (nataliabidart) wrote :

Thank you! Looks great.

review: Approve

Preview Diff

1=== modified file 'src/api/v20/auth.py'
2--- src/api/v20/auth.py 2016-03-18 18:40:31 +0000
3+++ src/api/v20/auth.py 2018-01-23 17:52:15 +0000
4@@ -1,4 +1,4 @@
5-# Copyright 2010-2012 Canonical Ltd. This software is licensed under the
6+# Copyright 2010-2018 Canonical Ltd. This software is licensed under the
7 # GNU Affero General Public License version 3 (see the file LICENSE).
8 import json
9 import logging
10@@ -72,6 +72,15 @@
11 return result
12
13
14+class ApiAccountsBulkAuthentication(SSOOAuthAuthentication):
15+
16+ def challenge(self):
17+ """Return a json body 401 response."""
18+ response = errors.INVALID_CREDENTIALS()
19+ response['WWW-Authenticate'] = 'OAuth realm="%s"' % self.realm
20+ return response
21+
22+
23 class ApiAccountRegistrationAuthentication(ApiOAuthAuthentication):
24
25 def _has_ownership(self, request):
26
27=== modified file 'src/api/v20/authorization.py'
28--- src/api/v20/authorization.py 2016-03-15 18:38:16 +0000
29+++ src/api/v20/authorization.py 2018-01-23 17:52:15 +0000
30@@ -1,11 +1,11 @@
31-# Copyright 2010-2016 Canonical Ltd. This software is licensed under the
32+# Copyright 2010-2018 Canonical Ltd. This software is licensed under the
33 # GNU Affero General Public License version 3 (see the file LICENSE).
34 import functools
35
36 from api.v20.utils import errors
37
38
39-def is_authorized(request):
40+def is_authorized(request, edit=None):
41 """Checks if logged in user has permissions based on request method.
42
43 This helper method's main use is for keeping backwards compatibility
44@@ -13,21 +13,26 @@
45
46 New API views should check permissions with @authorization_required.
47 """
48- if request.method == 'GET':
49+ if edit is None:
50+ edit = request.method != 'GET'
51+ if edit:
52+ perm = 'identityprovider.api_edit_account_details'
53+ else:
54 perm = 'identityprovider.api_view_account_details'
55- else:
56- perm = 'identityprovider.api_edit_account_details'
57 account = request.user
58 return account.user.has_perm(perm)
59
60
61-def authorization_required(func):
62+def authorization_required(edit=None):
63 """View decorator that checks logged in user has specific permissions."""
64
65- @functools.wraps(func)
66- def wrapped(self, request, *args, **kwargs):
67- if not is_authorized(request):
68- return errors.FORBIDDEN()
69- return func(self, request, *args, **kwargs)
70-
71- return wrapped
72+ def wrapper(func):
73+ @functools.wraps(func)
74+ def wrapped(self, request, *args, **kwargs):
75+ if not is_authorized(request, edit=edit):
76+ return errors.FORBIDDEN()
77+ return func(self, request, *args, **kwargs)
78+
79+ return wrapped
80+
81+ return wrapper
82
83=== modified file 'src/api/v20/handlers.py'
84--- src/api/v20/handlers.py 2016-12-15 21:46:04 +0000
85+++ src/api/v20/handlers.py 2018-01-23 17:52:15 +0000
86@@ -1,4 +1,4 @@
87-# Copyright 2010-2016 Canonical Ltd. This software is licensed under the
88+# Copyright 2010-2018 Canonical Ltd. This software is licensed under the
89 # GNU Affero General Public License version 3 (see the file LICENSE).
90
91 import json
92@@ -27,6 +27,7 @@
93 from api.v20.utils import (
94 errors,
95 get_account_data,
96+ get_accounts_data,
97 get_email_data,
98 get_minimal_account_data,
99 get_token_data,
100@@ -148,7 +149,7 @@
101 expand = (request.GET.get('expand', '').lower() == 'true')
102 return get_account_data(account, expand=expand)
103
104- @authorization_required
105+ @authorization_required(edit=True)
106 @require_mime('json')
107 @add_piston_http_patch_support
108 @validate(AccountStatusForm, 'data')
109@@ -164,6 +165,36 @@
110 return get_account_data(account)
111
112
113+class AccountsBulkHandler(BaseHandler):
114+ allowed_methods = ('POST',)
115+
116+ # no throttle limit since this endpoint is used by privileged accounts only
117+ @authorization_required(edit=False)
118+ @require_mime('json')
119+ def create(self, request):
120+ """Get information about a set of accounts.
121+
122+ We use POST for this to avoid needing to worry about practical
123+ limits on the length of the query string.
124+
125+ """
126+ data = request.data
127+
128+ try:
129+ openids = data['openids']
130+ except KeyError:
131+ missing = {'openids': [FIELD_REQUIRED]}
132+ return errors.INVALID_DATA(new_style=True, **missing)
133+ expand = data.get('expand', False)
134+
135+ accounts = Account.objects.prefetch_related(
136+ 'emailaddress_set',
137+ 'token_set',
138+ ).filter(
139+ openid_identifier__in=openids, status__exact=AccountStatus.ACTIVE)
140+ return get_accounts_data(accounts, expand=expand)
141+
142+
143 class PasswordResetTokenHandler(BaseHandler):
144 allowed_methods = ('POST',)
145
146@@ -596,7 +627,7 @@
147 result = errors.RESOURCE_NOT_FOUND()
148 return result
149
150- @authorization_required
151+ @authorization_required(edit=True)
152 @require_mime('json')
153 @add_piston_http_patch_support
154 @validate(AccountStatusForm, 'data')
155
156=== modified file 'src/api/v20/tests/test_handlers.py'
157--- src/api/v20/tests/test_handlers.py 2016-12-07 16:28:36 +0000
158+++ src/api/v20/tests/test_handlers.py 2018-01-23 17:52:15 +0000
159@@ -48,6 +48,7 @@
160 from identityprovider.tests.test_auth import AuthLogTestCaseMixin
161 from identityprovider.tests.utils import SSOBaseTestCase, TimelineActionMixin
162 from identityprovider.utils import redirection_url_for_token
163+from testing.helpers import assert_no_extra_queries_after
164
165
166 OVERRIDES = dict(
167@@ -328,6 +329,99 @@
168 self.assert_account_data_expanded(' ', False)
169
170
171+class AccountsBulkHandlerTestCase(BaseTestCase):
172+
173+ url = reverse('api-accounts-bulk')
174+
175+ def setUp(self):
176+ super(AccountsBulkHandlerTestCase, self).setUp()
177+ account = self.factory.make_account(
178+ email='super@user.com', permissions=['api_view_account_details'])
179+ self.super_token = account.create_oauth_token('super-token')
180+
181+ def test_any_server_error_is_json(self):
182+ self.assert_any_server_error_is_json(self.do_post, token=self.token)
183+
184+ def test_post_unauthenticated_401(self):
185+ self.do_post(status_code=401)
186+ self.do_post({'openids': []}, status_code=401)
187+
188+ def test_post_authenticated_with_django_401(self):
189+ assert self.client.login(
190+ username='super@user.com', password=DEFAULT_USER_PASSWORD)
191+ self.do_post(status_code=401)
192+ self.do_post({'openids': []}, status_code=401)
193+
194+ def test_post_authenticated_with_oauth_unprivileged(self):
195+ json_body = self.do_post(
196+ {'openids': []}, token=self.token, status_code=403)
197+ self.assertEqual(json_body['code'], 'FORBIDDEN')
198+ self.assertEqual(json_body['extra'], {})
199+
200+ def test_post_authenticated_no_match(self):
201+ json_body = self.do_post({'openids': []}, token=self.super_token)
202+ self.assertEqual(json_body, self.to_json({}))
203+
204+ def test_post_authenticated_match(self):
205+ accounts = [self.factory.make_account() for _ in range(4)]
206+ json_body = self.do_post(
207+ {
208+ 'openids': [
209+ account.openid_identifier for account in accounts[:3]
210+ ],
211+ },
212+ token=self.super_token)
213+
214+ expected = handlers.get_accounts_data(accounts[:3])
215+ self.assertEqual(json_body, self.to_json(expected))
216+
217+ def test_post_authenticated_match_expanded(self):
218+ self.maxDiff = None
219+ accounts = [self.factory.make_account() for _ in range(4)]
220+ json_body = self.do_post(
221+ {
222+ 'openids': [
223+ account.openid_identifier for account in accounts[:3]
224+ ],
225+ 'expand': True,
226+ },
227+ token=self.super_token)
228+
229+ expected = handlers.get_accounts_data(accounts[:3], expand=True)
230+ self.assertEqual(json_body, self.to_json(expected))
231+
232+ def test_constant_query_count(self):
233+ accounts = []
234+
235+ def add_accounts():
236+ accounts.extend([
237+ self.factory.make_account(email_validated=(i % 2) == 0)
238+ for i in range(4)
239+ ])
240+ for i, account in enumerate(accounts[-4:]):
241+ if (i % 2) == 0:
242+ self.factory.make_person(account=account)
243+ for _ in range(10):
244+ self.factory.make_email_for_account(account=account)
245+ for _ in range(10):
246+ self.factory.make_oauth_token(account=account)
247+
248+ def do_request():
249+ return self.do_post(
250+ {
251+ 'openids': [
252+ account.openid_identifier for account in accounts
253+ ],
254+ 'expand': True,
255+ },
256+ token=self.super_token)
257+
258+ add_accounts()
259+ json_body = assert_no_extra_queries_after(add_accounts, do_request)
260+ expected = handlers.get_accounts_data(accounts, expand=True)
261+ self.assertEqual(json_body, self.to_json(expected))
262+
263+
264 class AnonymousAccountRegistrationHandlerTestCase(BaseTestCase):
265
266 url = reverse('api-registration')
267
268=== modified file 'src/api/v20/tests/test_utils.py'
269--- src/api/v20/tests/test_utils.py 2016-08-05 08:44:31 +0000
270+++ src/api/v20/tests/test_utils.py 2018-01-23 17:52:15 +0000
271@@ -1,13 +1,17 @@
272 import json
273 from datetime import timedelta
274+from functools import partial
275+from operator import attrgetter
276
277 from django.core.urlresolvers import reverse
278
279 from api.v20.utils import (
280 EnsureJSONResponseOnAPIErrorMiddleware,
281 get_account_data,
282+ get_accounts_data,
283 get_email_data,
284 get_minimal_account_data,
285+ get_minimal_accounts_data,
286 get_token_data,
287 )
288 from identityprovider.models.const import EmailStatus
289@@ -41,6 +45,64 @@
290 self.assertEqual(json.loads(response.content), expected_data)
291
292
293+class PluralTestCase(SSOBaseTestCase):
294+
295+ def assert_equivalent_plural(self, plural_func, singular_func, objects,
296+ get_key):
297+ """Assert that a plural function matches its singular sibling.
298+
299+ `plural_func` applied to all objects at once produces the same
300+ results as `singular_func` applied to each object in turn.
301+ """
302+ plural_results = plural_func(objects)
303+ singular_results = {
304+ get_key(obj): singular_func(obj) for obj in objects
305+ }
306+ self.assertEqual(plural_results, singular_results)
307+
308+
309+class GetAccountsDataTestCase(PluralTestCase):
310+
311+ def make_accounts(self, count):
312+ return [
313+ self.factory.make_account(email_validated=(i % 2) == 0)
314+ for i in range(count)
315+ ]
316+
317+ def test_get_minimal_accounts_data(self):
318+ accounts = self.make_accounts(4)
319+ self.factory.make_person(account=accounts[0])
320+
321+ self.assert_equivalent_plural(
322+ get_minimal_accounts_data, get_minimal_account_data,
323+ accounts[:3], attrgetter('openid_identifier'))
324+
325+ def test_get_accounts_data(self):
326+ accounts = self.make_accounts(4)
327+ self.factory.make_person(account=accounts[0])
328+ for _ in range(15):
329+ self.factory.make_email_for_account(account=accounts[1])
330+ for _ in range(15):
331+ self.factory.make_oauth_token(account=accounts[2])
332+
333+ self.assert_equivalent_plural(
334+ get_accounts_data, partial(get_account_data, include_tokens=False),
335+ accounts[:3], attrgetter('openid_identifier'))
336+ self.assert_equivalent_plural(
337+ partial(get_accounts_data, limit=5),
338+ partial(get_account_data, limit=5, include_tokens=False),
339+ accounts[:3], attrgetter('openid_identifier'))
340+ self.assert_equivalent_plural(
341+ partial(get_accounts_data, expand=True),
342+ partial(get_account_data, expand=True, include_tokens=False),
343+ accounts[:3], attrgetter('openid_identifier'))
344+ self.assert_equivalent_plural(
345+ partial(get_accounts_data, limit=5, expand=True),
346+ partial(
347+ get_account_data, limit=5, expand=True, include_tokens=False),
348+ accounts[:3], attrgetter('openid_identifier'))
349+
350+
351 class GetAccountDataTestCase(SSOBaseTestCase):
352
353 openid = '1234567'
354
355=== modified file 'src/api/v20/urls.py'
356--- src/api/v20/urls.py 2016-11-01 10:53:35 +0000
357+++ src/api/v20/urls.py 2018-01-23 17:52:15 +0000
358@@ -1,10 +1,11 @@
359-# Copyright 2010 Canonical Ltd. This software is licensed under the
360+# Copyright 2010-2018 Canonical Ltd. This software is licensed under the
361 # GNU Affero General Public License version 3 (see the file LICENSE).
362
363 from django.conf.urls import patterns, url
364
365 from api.v20.auth import (
366 ApiAccountsAuthentication,
367+ ApiAccountsBulkAuthentication,
368 ApiAccountRegistrationAuthentication,
369 ApiEmailsAuthentication,
370 ApiTokensAuthentication,
371@@ -13,6 +14,7 @@
372 from api.v20.handlers import (
373 AccountLoginHandler,
374 AccountRegistrationHandler,
375+ AccountsBulkHandler,
376 AccountsHandler,
377 EmailsHandler,
378 MacaroonDischargeHandler,
379@@ -26,6 +28,9 @@
380
381 v2accounts = ApiResource(
382 handler=AccountsHandler, authentication=ApiAccountsAuthentication())
383+v2accounts_bulk = ApiResource(
384+ handler=AccountsBulkHandler,
385+ authentication=ApiAccountsBulkAuthentication())
386 v2emails = ApiResource(
387 handler=EmailsHandler, authentication=ApiEmailsAuthentication())
388 v2login = ApiResource(handler=AccountLoginHandler)
389@@ -44,6 +49,7 @@
390 '',
391 url(r'^accounts$', v2registration, name='api-registration'),
392 url(r'^accounts/(?P<openid>\w+)$', v2accounts, name='api-account'),
393+ url(r'^accounts-bulk$', v2accounts_bulk, name='api-accounts-bulk'),
394 url(r'^emails/(?P<email>.+)$', v2emails, name='api-email'),
395 url(r'^requests/validate$', v2requests, name='api-requests'),
396 url(r'^tokens/discharge$', v2macaroon_discharge,
397
398=== modified file 'src/api/v20/utils.py'
399--- src/api/v20/utils.py 2016-09-14 16:47:06 +0000
400+++ src/api/v20/utils.py 2018-01-23 17:52:15 +0000
401@@ -1,12 +1,18 @@
402-# Copyright 2010-2012 Canonical Ltd. This software is licensed under the
403+# Copyright 2010-2018 Canonical Ltd. This software is licensed under the
404 # GNU Affero General Public License version 3 (see the file LICENSE).
405+
406 import json
407+from operator import attrgetter
408
409 from django.core.urlresolvers import reverse
410 from django.http import HttpResponse
411 from django.utils.translation import ugettext_lazy as _
412
413-from identityprovider.models.const import AccountStatus
414+from identityprovider.models.account import Account
415+from identityprovider.models.const import (
416+ AccountStatus,
417+ TokenScope,
418+)
419
420
421 class EnsureJSONResponseOnAPIErrorMiddleware(object):
422@@ -180,40 +186,83 @@
423 return unicode(AccountStatus._verbose[status])
424
425
426+def get_minimal_accounts_data(accounts):
427+ """Return the minimal non-private data for a sequence of accounts."""
428+ Account.objects.prefetch_person(accounts)
429+ # Given that we've prefetched related EmailAddress rows already, this is
430+ # equivalent to account.is_verified for each account: an account is
431+ # verified if it has at least one PREFERRED or VALIDATED email address.
432+ is_verified = {
433+ account: (
434+ account.preferredemail is not None and
435+ account.preferredemail.is_verified)
436+ for account in accounts
437+ }
438+ return {
439+ account.openid_identifier: {
440+ 'href': reverse('api-account', args=(account.openid_identifier,)),
441+ 'openid': account.openid_identifier,
442+ 'verified': is_verified[account],
443+ 'username': account.person_name,
444+ }
445+ for account in accounts
446+ }
447+
448+
449 def get_minimal_account_data(account):
450 """Return the minimal non-private data for the account."""
451- href = reverse('api-account', args=(account.openid_identifier,))
452- data = dict(
453- href=href,
454- openid=account.openid_identifier,
455- verified=account.is_verified,
456- username=account.person_name,
457- )
458+ return get_minimal_accounts_data([account])[account.openid_identifier]
459+
460+
461+def get_accounts_data(accounts, limit=10, expand=False, include_tokens=False):
462+ """Get the relevant data from a sequence of accounts to be serialized.
463+
464+ Emails and tokens will be limited to the latest 10 results for each
465+ account, where latest for email means creation date, and for tokens
466+ means updated date.
467+
468+ """
469+ data = get_minimal_accounts_data(accounts)
470+
471+ for account in accounts:
472+ # Don't use helpers such as Account.oauth_tokens or directly slice
473+ # QuerySets here, since that would invalidate previous
474+ # prefetch_related calls. Iterating over the whole
475+ # EmailAddress/Token sets shouldn't be too bad in this case.
476+ emails = sorted(
477+ account.emailaddress_set.all(),
478+ key=attrgetter('date_created'), reverse=True)[:limit]
479+ data[account.openid_identifier].update({
480+ 'email': (
481+ None if account.preferredemail is None
482+ else account.preferredemail.email),
483+ 'displayname': account.displayname,
484+ 'status': _get_account_status_text(account.status),
485+ 'emails': [get_email_data(e, expand=expand) for e in emails],
486+ })
487+ if include_tokens:
488+ tokens = sorted(
489+ (token for token in account.token_set.all()
490+ if token.scope == TokenScope.VERSION_2),
491+ key=attrgetter('date_updated'), reverse=True)[:limit]
492+ data[account.openid_identifier]['tokens'] = [
493+ get_token_data(t, expand=expand) for t in tokens]
494+
495 return data
496
497
498-def get_account_data(account, limit=10, expand=False):
499+# When fetching a single account, include_tokens defaults to True for
500+# historical reasons, although the tokens should now be mostly unused.
501+def get_account_data(account, limit=10, expand=False, include_tokens=True):
502 """Get the relevant data from an account to be serialized.
503
504 Emails and tokens will be limited to the latest 10 results, where latest
505- for email means creation date, and for tokens mean updated date.
506+ for email means creation date, and for tokens means updated date.
507
508 """
509- data = get_minimal_account_data(account)
510- email = None
511- if account.preferredemail is not None:
512- email = account.preferredemail.email
513-
514- emails = account.emails(limit=limit)
515- tokens = account.oauth_tokens(limit=limit)
516- data.update(
517- email=email,
518- displayname=account.displayname,
519- status=_get_account_status_text(account.status),
520- emails=[get_email_data(e, expand) for e in emails],
521- tokens=[get_token_data(t, expand) for t in tokens],
522- )
523- return data
524+ return get_accounts_data(
525+ [account], limit=limit, expand=expand,
526+ include_tokens=include_tokens)[account.openid_identifier]
527
528
529 def get_email_data(email, expand=True):
530
531=== modified file 'src/identityprovider/models/account.py'
532--- src/identityprovider/models/account.py 2017-07-20 14:22:30 +0000
533+++ src/identityprovider/models/account.py 2018-01-23 17:52:15 +0000
534@@ -1,4 +1,4 @@
535-# Copyright 2010-2016 Canonical Ltd. This software is licensed under
536+# Copyright 2010-2018 Canonical Ltd. This software is licensed under
537 # the GNU Affero General Public License version 3 (see the file
538 # LICENSE).
539
540@@ -6,6 +6,7 @@
541 from __future__ import unicode_literals
542
543 import logging
544+from operator import attrgetter
545
546 from django.conf import settings
547 from django.contrib.auth import get_backends
548@@ -147,6 +148,13 @@
549 return get_object_or_none(self, openid_identifier=openid_identifier,
550 status=AccountStatus.ACTIVE)
551
552+ def prefetch_person(self, accounts):
553+ """Pre-fetch person for a set of accounts."""
554+ from identityprovider.models.person import Person
555+ for account, person in Person.objects.get_by_account(
556+ accounts).items():
557+ account.__dict__['person'] = person
558+
559
560 class DisplaynameField(models.TextField):
561 def __init__(self, null=False, **kwargs):
562@@ -317,18 +325,26 @@
563 if email is None or not email.is_verified:
564 try:
565 email = None
566- account_emails = self.emailaddress_set.filter(
567- status=EmailStatus.PREFERRED)
568- if account_emails.count() > 0:
569- email = account_emails[0]
570+ # We assume that accounts only ever have a smallish number
571+ # of email addresses; with that assumption, doing this
572+ # rather than filtering on separate statuses is more
573+ # prefetch-friendly and involves fewer DB round-trips.
574+ all_emails = sorted(
575+ self.emailaddress_set.all(),
576+ key=attrgetter('date_created'))
577+ emails = [
578+ e for e in all_emails if e.status == EmailStatus.PREFERRED]
579+ if emails:
580+ email = emails[0]
581
582 if not email:
583 # Try to determine a suitable address, and mark it
584 # as preferred.
585- emails = self.emailaddress_set.filter(
586- status=EmailStatus.VALIDATED)
587- if emails.count() > 0:
588- email = emails.order_by('date_created')[0]
589+ emails = [
590+ e for e in all_emails
591+ if e.status == EmailStatus.VALIDATED]
592+ if emails:
593+ email = emails[0]
594 email.status = EmailStatus.PREFERRED
595 email.save()
596 logger.info(
597@@ -339,10 +355,10 @@
598 if not email:
599 # we have no validated email, so use the first NEW email
600 # but don't save it
601- emails = self.emailaddress_set.filter(
602- status=EmailStatus.NEW)
603- if emails.count() > 0:
604- email = emails.order_by('date_created')[0]
605+ emails = [
606+ e for e in all_emails if e.status == EmailStatus.NEW]
607+ if emails:
608+ email = emails[0]
609
610 self._preferredemail = email
611 except:
612
613=== modified file 'src/identityprovider/models/person.py'
614--- src/identityprovider/models/person.py 2016-04-28 13:01:03 +0000
615+++ src/identityprovider/models/person.py 2018-01-23 17:52:15 +0000
616@@ -1,9 +1,12 @@
617-# Copyright 2010 Canonical Ltd. This software is licensed under the
618+# Copyright 2010-2018 Canonical Ltd. This software is licensed under the
619 # GNU Affero General Public License version 3 (see the file LICENSE).
620
621 from __future__ import unicode_literals
622
623-from django.db import models
624+from django.db import (
625+ connection,
626+ models,
627+)
628
629 from identityprovider.const import PERSON_VISIBILITY_PUBLIC
630 from identityprovider.models.account import Account, LPOpenIdIdentifier
631@@ -16,6 +19,44 @@
632 )
633
634
635+class PersonManager(models.Manager):
636+ """A custom manager for Person models."""
637+
638+ def get_by_account(self, accounts):
639+ """Return Person instances for a set of accounts.
640+
641+ The return value is a dict mapping Account instances to Person
642+ instances.
643+
644+ """
645+ if not accounts:
646+ return {}
647+ accounts_by_id = {account.id: account for account in accounts}
648+ # This is rather cumbersome because we don't have Django model
649+ # relationships for the tables synced over from Launchpad.
650+ cursor = connection.cursor()
651+ cursor.execute("""
652+ SELECT account.id, lp_person.id
653+ FROM account, lp_openididentifier, lp_person
654+ WHERE
655+ lp_openididentifier.account = lp_person.account
656+ AND lp_openididentifier.identifier = account.openid_identifier
657+ AND account.id IN %s
658+ """, (tuple(accounts_by_id),))
659+ rows = cursor.fetchall()
660+ persons = self.filter(
661+ id__in={person_id for _, person_id in rows}).all()
662+ persons_by_id = {person.id: person for person in persons}
663+ result = {
664+ accounts_by_id[account_id]: persons_by_id[person_id]
665+ for account_id, person_id in rows
666+ }
667+ for account in accounts:
668+ if account not in result:
669+ result[account] = None
670+ return result
671+
672+
673 class Person(models.Model):
674 displayname = models.TextField(null=True, blank=True)
675 teamowner = models.IntegerField(db_column=b'teamowner',
676@@ -65,6 +106,8 @@
677 lp_account = models.IntegerField(null=True, db_column=b'account',
678 unique=True)
679
680+ objects = PersonManager()
681+
682 class Meta:
683 app_label = 'identityprovider'
684 db_table = u'lp_person'
685
686=== modified file 'src/identityprovider/signals.py'
687--- src/identityprovider/signals.py 2016-05-30 21:41:39 +0000
688+++ src/identityprovider/signals.py 2018-01-23 17:52:15 +0000
689@@ -1,4 +1,4 @@
690-# Copyright 2010 Canonical Ltd. This software is licensed under the
691+# Copyright 2010-2018 Canonical Ltd. This software is licensed under the
692 # GNU Affero General Public License version 3 (see the file LICENSE).
693 import logging
694 import traceback
695@@ -16,6 +16,7 @@
696 from identityprovider.const import SESSION_TOKEN_KEY, SESSION_TOKEN_NAME
697 from identityprovider.models import Account, AccountPassword, AuthLog, Token
698 from identityprovider.models.const import AuthLogType, TokenScope
699+from identityprovider.models.emailaddress import EmailAddress
700 from webservices.utils import http_request_with_timeout
701
702
703@@ -180,3 +181,17 @@
704 )
705
706 pre_delete.connect(track_isdtest_account_deletion, sender=Account)
707+
708+
709+# account.preferredemail may be prefetched for API performance, but deleting
710+# EmailAddress rows may invalidate the prefetched value.
711+def invalidate_preferredemail(sender, instance, using, **kwargs):
712+ account_cache_name = instance._meta.get_field('account').get_cache_name()
713+ account = getattr(instance, account_cache_name, None)
714+ if account is not None:
715+ try:
716+ del account._preferredemail
717+ except AttributeError:
718+ pass
719+
720+pre_delete.connect(invalidate_preferredemail, sender=EmailAddress)
721
722=== modified file 'src/identityprovider/tests/test_models_account.py'
723--- src/identityprovider/tests/test_models_account.py 2016-12-13 14:30:16 +0000
724+++ src/identityprovider/tests/test_models_account.py 2018-01-23 17:52:15 +0000
725@@ -1,4 +1,4 @@
726-# Copyright 2010-2016 Canonical Ltd. This software is licensed under the
727+# Copyright 2010-2018 Canonical Ltd. This software is licensed under the
728 # GNU Affero General Public License version 3 (see the file LICENSE).
729
730
731@@ -252,6 +252,15 @@
732 else:
733 self.assertIsNone(result)
734
735+ def test_prefetch_person(self):
736+ accounts = [self.factory.make_account() for _ in range(3)]
737+ persons = [
738+ self.factory.make_person(account=account)
739+ for account in accounts[:2]] + [None]
740+ Account.objects.prefetch_person(accounts)
741+ with self.assertNumQueries(0):
742+ self.assertEqual([account.person for account in accounts], persons)
743+
744 def test_verified_list_only_verified_accounts(self):
745 accounts = Account.objects.verified()
746 for account in accounts:
747
748=== modified file 'src/identityprovider/tests/test_models_person.py'
749--- src/identityprovider/tests/test_models_person.py 2016-04-28 13:01:03 +0000
750+++ src/identityprovider/tests/test_models_person.py 2018-01-23 17:52:15 +0000
751@@ -1,4 +1,4 @@
752-# Copyright 2010 Canonical Ltd. This software is licensed under the
753+# Copyright 2010-2018 Canonical Ltd. This software is licensed under the
754 # GNU Affero General Public License version 3 (see the file LICENSE).
755 from random import randint
756
757@@ -18,6 +18,45 @@
758 AccountStatus,
759 )
760 from identityprovider.tests.utils import SSOBaseTestCase
761+from testing.helpers import assert_no_extra_queries_after
762+
763+
764+class PersonManagerTestCase(SSOBaseTestCase):
765+
766+ def test_get_by_account_no_accounts(self):
767+ with self.assertNumQueries(0):
768+ self.assertEqual(Person.objects.get_by_account([]), {})
769+
770+ def test_get_by_account(self):
771+ accounts = [self.factory.make_account() for _ in range(10)]
772+ persons = [
773+ self.factory.make_person(account=account)
774+ for account in accounts[:5]]
775+ self.factory.make_person()
776+
777+ with self.assertNumQueries(2):
778+ self.assertEqual(
779+ Person.objects.get_by_account(accounts[:3] + accounts[5:8]),
780+ {
781+ accounts[0]: persons[0],
782+ accounts[1]: persons[1],
783+ accounts[2]: persons[2],
784+ accounts[5]: None,
785+ accounts[6]: None,
786+ accounts[7]: None,
787+ })
788+
789+ def test_get_by_account_constant_query_count(self):
790+ accounts = []
791+
792+ def add_accounts():
793+ accounts.append(self.factory.make_account())
794+ self.factory.make_person(account=accounts[-1])
795+ accounts.append(self.factory.make_account())
796+
797+ add_accounts()
798+ assert_no_extra_queries_after(
799+ add_accounts, Person.objects.get_by_account, accounts)
800
801
802 class PersonTestCase(SSOBaseTestCase):
803
804=== modified file 'src/identityprovider/tests/test_signals.py'
805--- src/identityprovider/tests/test_signals.py 2015-11-05 16:43:46 +0000
806+++ src/identityprovider/tests/test_signals.py 2018-01-23 17:52:15 +0000
807@@ -1,28 +1,37 @@
808 # -*- coding: utf-8 -*-
809
810-# Copyright 2013 Canonical Ltd. This software is licensed under the
811+# Copyright 2013-2018 Canonical Ltd. This software is licensed under the
812 # GNU Affero General Public License version 3 (see the file LICENSE).
813
814 from django.conf import settings
815 from django.contrib.auth.signals import user_logged_in
816 from django.core.urlresolvers import reverse
817-from django.db.models.signals import post_save
818+from django.db.models.signals import (
819+ post_save,
820+ pre_delete,
821+)
822 from django.utils.timezone import now
823
824 from mock import Mock, patch
825
826 from identityprovider.const import SESSION_TOKEN_KEY, SESSION_TOKEN_NAME
827-from identityprovider.models.const import AuthLogType, TokenScope
828+from identityprovider.models.const import (
829+ AuthLogType,
830+ EmailStatus,
831+ TokenScope,
832+)
833+from identityprovider.models.emailaddress import EmailAddress
834+from identityprovider.readonly import ReadOnlyManager
835 from identityprovider.signals import (
836 invalidate_account_oauth_tokens,
837- login_failed,
838+ invalidate_preferredemail,
839 log_login_failed,
840 log_login_succeeded,
841+ login_failed,
842+ login_succeeded,
843 set_session_oauth_token,
844 track_failed_login,
845- login_succeeded
846 )
847-from identityprovider.readonly import ReadOnlyManager
848 from identityprovider.tests import DEFAULT_USER_PASSWORD
849 from identityprovider.tests.test_auth import AuthLogTestCaseMixin
850 from identityprovider.tests.utils import SSOBaseTestCase
851@@ -304,3 +313,53 @@
852 log_login_succeeded(
853 sender, self.account, request, email='foo@bar.com',
854 authlogtype=AuthLogType.OPENID_EXISTING)
855+
856+
857+class InvalidatePreferredEmailTestCase(SSOBaseTestCase):
858+
859+ def test_signal_connected(self):
860+ registered_functions = [r[1]() for r in pre_delete.receivers]
861+ self.assertIn(invalidate_preferredemail, registered_functions)
862+
863+ def test_invalidates_on_email_address_deletion(self):
864+ account = self.factory.make_account()
865+ email = self.factory.make_email_for_account(
866+ account, status=EmailStatus.PREFERRED)
867+ self.assertEqual(account.preferredemail, email)
868+ # Avoid spurious pass if the property caching doesn't work.
869+ with self.assertNumQueries(0):
870+ self.assertEqual(account.preferredemail, email)
871+
872+ account.emailaddress_set.all().delete()
873+ with self.assertNumQueries(1):
874+ self.assertIsNone(account.preferredemail)
875+
876+ def test_leaves_unrelated_account_untouched(self):
877+ accounts = [self.factory.make_account() for _ in range(2)]
878+ emails = [
879+ self.factory.make_email_for_account(
880+ account=account, status=EmailStatus.PREFERRED)
881+ for account in accounts
882+ ]
883+ for account, email in zip(accounts, emails):
884+ self.assertEqual(account.preferredemail, email)
885+ with self.assertNumQueries(0):
886+ self.assertEqual(account.preferredemail, email)
887+
888+ accounts[0].emailaddress_set.all().delete()
889+ with self.assertNumQueries(1):
890+ self.assertIsNone(accounts[0].preferredemail)
891+ with self.assertNumQueries(0):
892+ self.assertEqual(accounts[1].preferredemail, emails[1])
893+
894+ def test_ignores_email_without_account(self):
895+ person = self.factory.make_person()
896+ email = EmailAddress.objects.create(
897+ email=self.factory.make_email_address(),
898+ lp_person=person.id,
899+ status=EmailStatus.PREFERRED,
900+ account=None)
901+
902+ # There's no account to modify in this case; just make sure it
903+ # doesn't raise an exception.
904+ EmailAddress.objects.filter(email=email.email).delete()
905
906=== added file 'src/testing/helpers.py'
907--- src/testing/helpers.py 1970-01-01 00:00:00 +0000
908+++ src/testing/helpers.py 2018-01-23 17:52:15 +0000
909@@ -0,0 +1,66 @@
910+# Copyright 2018 Canonical Ltd. This software is licensed under the
911+# GNU Affero General Public License version 3 (see the file LICENSE).
912+
913+from __future__ import absolute_import, print_function, unicode_literals
914+
915+from operator import itemgetter
916+
917+from django.db import connection
918+from django.test.utils import CaptureQueriesContext
919+
920+
921+def format_queries(qs):
922+ return "\n\n".join(map(str, map(itemgetter('sql'), qs)))
923+
924+
925+def assert_no_extra_queries_after(modify_data, f, *args, **kwargs):
926+ """Assert a function issues the same number of queries after a change.
927+
928+ Many views perform terribly because they perform extra queries
929+ for every item in some list. For example when looping over a
930+ list of users to present them, if a value is accessed through
931+ a foreign key then it may perform an extra query. If care
932+ is not taken then the view may perform terribly when
933+ operating on a large list. This is often worked around via
934+ .prefetch_related() or .select_related().
935+
936+ This function helps in testing that this doesn't happen. The arguments
937+ to this function are a callable to modify data, and a function
938+ to call before and after (*args and **kwargs are passed
939+ through to the latter function). Both calls to the function
940+ should issue the same number of queries.
941+
942+ For instance
943+
944+ self.factory.make_user()
945+ assert_no_extra_queries_after(
946+ self.factory.make_user,
947+ self.client.get, '/users')
948+
949+ This will first request the view and record how many queries were
950+ done. It will then run the modify function, creating another user
951+ in this case, and then request the view again. If the second
952+ request doesn't issue the same number of queries then an
953+ AssertionError will be raised.
954+
955+ """
956+
957+ first = CaptureQueriesContext(connection)
958+ with first:
959+ f(*args, **kwargs)
960+ modify_data()
961+ second = CaptureQueriesContext(connection)
962+ with second:
963+ result = f(*args, **kwargs)
964+
965+ if len(first) != len(second):
966+ message = ("Different number of queries for the "
967+ "second execution. {} != {}".format(
968+ len(first), len(second)))
969+ message += "\nFirst queries:\n"
970+ message += format_queries(first.captured_queries)
971+ message += "\n\n\nSecond queries:\n"
972+ message += format_queries(second.captured_queries)
973+ raise AssertionError(message)
974+
975+ return result
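
The constant-query-count check that assert_no_extra_queries_after performs can be tried outside Django. The following is only a minimal sketch under stated assumptions: it uses the standard-library sqlite3 module and its trace callback in place of Django's CaptureQueriesContext, and the person table, QueryCounter, people_naive, people_bulk and assert_constant_queries are hypothetical names invented for illustration, not code from this branch. It shows a per-account lookup failing the check and a single IN query passing it, with accounts that have no person mapped to None as in test_get_by_account.

```python
import sqlite3


class QueryCounter(object):
    """Record the SQL statements run on a sqlite3 connection (hypothetical)."""

    def __init__(self, conn):
        self.conn = conn
        self.queries = []

    def __enter__(self):
        self.conn.set_trace_callback(self.queries.append)
        return self

    def __exit__(self, *exc_info):
        self.conn.set_trace_callback(None)
        return False


def assert_constant_queries(conn, modify_data, f, *args, **kwargs):
    """Call f, modify the data, call f again; fail if the counts differ."""
    with QueryCounter(conn) as first:
        f(*args, **kwargs)
    modify_data()
    with QueryCounter(conn) as second:
        result = f(*args, **kwargs)
    if len(first.queries) != len(second.queries):
        raise AssertionError("Different number of queries: {} != {}".format(
            len(first.queries), len(second.queries)))
    return result


def people_naive(conn, account_ids):
    # One SELECT per account: the query count grows with the input.
    result = {}
    for account_id in account_ids:
        row = conn.execute(
            "SELECT name FROM person WHERE account_id = ?",
            (account_id,)).fetchone()
        result[account_id] = row[0] if row else None
    return result


def people_bulk(conn, account_ids):
    # A single SELECT ... IN (...); accounts with no person stay None.
    result = dict.fromkeys(account_ids)
    if not account_ids:
        return result
    placeholders = ",".join("?" * len(account_ids))
    rows = conn.execute(
        "SELECT account_id, name FROM person "
        "WHERE account_id IN ({})".format(placeholders),
        list(account_ids))
    for account_id, name in rows:
        result[account_id] = name
    return result


# Autocommit mode, so no implicit BEGIN statements are issued.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute(
    "CREATE TABLE person (account_id INTEGER PRIMARY KEY, name TEXT)")
account_ids = []


def add_accounts():
    # Add one account with a person and one without.
    with_person = len(account_ids) + 1
    account_ids.extend([with_person, with_person + 1])
    conn.execute(
        "INSERT INTO person VALUES (?, ?)",
        (with_person, "person-{}".format(with_person)))


add_accounts()
bulk_result = assert_constant_queries(
    conn, add_accounts, people_bulk, conn, account_ids)
```

As in test_get_by_account_constant_query_count, the same mutable list is passed to both calls, so the second call sees the accounts added by add_accounts. Swapping people_naive in for people_bulk should raise AssertionError, because the lookup count grows with the list.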