diff --git a/admin/management/views.py b/admin/management/views.py index f2052822f37..4ec96a4ea3a 100644 --- a/admin/management/views.py +++ b/admin/management/views.py @@ -179,7 +179,17 @@ def post(self, request): class SyncNotificationTemplates(ManagementCommandPermissionView): def post(self, request): - populate_notification_types() + run_type = request.POST.get('run_type') + if run_type == 'restore_one': + template_name = request.POST.get('template_name') + if not template_name: + messages.error(request, 'A template name must be specified when restoring one template. Check your inputs and try again') + return redirect(reverse('management:commands')) + populate_notification_types(restore_one=template_name) + elif run_type == 'restore_all': + populate_notification_types(restore_all=True) + else: + populate_notification_types() messages.success(request, 'Notification templates have been successfully synced.') return redirect(reverse('management:commands')) diff --git a/admin/templates/management/commands.html b/admin/templates/management/commands.html index edf242abfdd..87ff147f919 100644 --- a/admin/templates/management/commands.html +++ b/admin/templates/management/commands.html @@ -160,6 +160,16 @@

Sync Notification Templates

{% csrf_token %} + + +
+ +
+ diff --git a/osf/management/commands/populate_notification_types.py b/osf/management/commands/populate_notification_types.py index 302c1069a17..5f0145f08a9 100644 --- a/osf/management/commands/populate_notification_types.py +++ b/osf/management/commands/populate_notification_types.py @@ -1,6 +1,5 @@ import sys import yaml -from django.apps import apps from waffle import switch_is_active from osf import features @@ -19,7 +18,7 @@ 'email_transactional': 'instantly', } -def populate_notification_types(*args, **kwargs): +def populate_notification_types(*args, restore_one=None, restore_all=False, **kwargs): if kwargs.get('sender'): # exists when called as a post_migrate signal if not switch_is_active(features.POPULATE_NOTIFICATION_TYPES): if 'pytest' not in sys.modules: @@ -28,64 +27,89 @@ def populate_notification_types(*args, **kwargs): logger.info('Populating notification types...') from django.contrib.contenttypes.models import ContentType from osf.models.notification_type import NotificationType + try: with open(settings.NOTIFICATION_TYPES_YAML) as stream: notification_types = yaml.safe_load(stream) - for notification_type in notification_types['notification_types']: - notification_type.pop('__docs__', None) - notification_type.pop('tests', None) - object_content_type_model_name = notification_type.pop('object_content_type_model_name') + + notification_types_dict = { + nt['name']: nt for nt in notification_types['notification_types'] + } + + all_names = set(notification_types_dict.keys()) + existing_names = set( + NotificationType.objects.values_list('name', flat=True) + ) + + if restore_one: + if restore_one not in notification_types_dict: + raise ValueError(f'Notification type "{restore_one}" not found in YAML') + names_to_process = {restore_one} + + elif restore_all: + names_to_process = all_names + + else: + names_to_process = all_names - existing_names + + logger.info(f'Processing {len(names_to_process)} notification types') + + for name in names_to_process: + raw_nt = notification_types_dict[name].copy() + + raw_nt.pop('__docs__', None) + raw_nt.pop('tests', None) + + object_content_type_model_name = raw_nt.pop('object_content_type_model_name') if object_content_type_model_name == 'desk': content_type = None - elif object_content_type_model_name == 'osfuser': - OSFUser = apps.get_model('osf', 'OSFUser') - content_type = ContentType.objects.get_for_model(OSFUser) - elif object_content_type_model_name == 'preprint': - Preprint = apps.get_model('osf', 'Preprint') - content_type = ContentType.objects.get_for_model(Preprint) - elif object_content_type_model_name == 'collectionsubmission': - CollectionSubmission = apps.get_model('osf', 'CollectionSubmission') - content_type = ContentType.objects.get_for_model(CollectionSubmission) - elif object_content_type_model_name == 'abstractprovider': - AbstractProvider = apps.get_model('osf', 'abstractprovider') - content_type = ContentType.objects.get_for_model(AbstractProvider) - elif object_content_type_model_name == 'osfuser': - OSFUser = apps.get_model('osf', 'OSFUser') - content_type = ContentType.objects.get_for_model(OSFUser) - elif object_content_type_model_name == 'draftregistration': - DraftRegistration = apps.get_model('osf', 'DraftRegistration') - content_type = ContentType.objects.get_for_model(DraftRegistration) else: try: - content_type = ContentType.objects.get( - app_label='osf', - model=object_content_type_model_name - ) + content_type = ContentType.objects.get_by_natural_key(app_label='osf', model=object_content_type_model_name) except ContentType.DoesNotExist: raise ValueError(f'No content type for osf.{object_content_type_model_name}') - template_path = notification_type.pop('template') + template_path = raw_nt.pop('template') + template = None + if template_path: with open(template_path) as stream: template = stream.read() nt, _ = NotificationType.objects.update_or_create( - name=notification_type['name'], - defaults=notification_type, + name=name, + defaults=raw_nt, ) + nt.object_content_type = content_type - if not nt.template or settings.DEV_MODE: + if template: nt.template = template + nt.save() + except ProgrammingError: logger.info('Notification types failed potential side effect of reverse migration') logger.info('Finished populating notification types.') - class Command(BaseCommand): - help = 'Population notification types.' + help = 'Populate notification types.' + + def add_arguments(self, parser): + parser.add_argument( + '--restore-all', + action='store_true', + help='Restore all templates from files' + ) + parser.add_argument( + '--restore', + type=str, + help='Restore specific template by name' + ) def handle(self, *args, **options): with transaction.atomic(): - populate_notification_types(args, options) + populate_notification_types( + restore_all=options['restore_all'], + restore_one=options['restore'] + )