Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 68 additions & 23 deletions backend/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import hashlib
from datetime import date, datetime
from pathlib import Path
from typing import Self, Union, List
from typing import Any, Self, Type, Union, List
import statistics

from django.utils import timezone
Expand Down Expand Up @@ -426,33 +426,78 @@ def update_dependencies(self) -> Union[str, None]:
if (error_msg := dependency.update()) not in [None, "libraryHasNoUpdate"]:
return error_msg

def update_threats(self):
for threat in self.threats:
normalized_urn = threat["urn"].lower()
Threat.objects.update_or_create(
urn=normalized_urn,
defaults=threat,
create_defaults={
**self.referential_object_dict,
**self.i18n_object_dict,
**threat,
"library": self.old_library,
},
def _synchronize_related_objects(
self,
*,
model_class: Type[models.Model],
incoming_data: list[dict[str, Any]],
unique_field: str = "urn",
) -> list:
"""Generic and database-agnostic method to synchronize related objects."""
if not incoming_data:
model_class.objects.filter(library=self.old_library).update(library=None)
return []

incoming_ids = {item[unique_field].lower() for item in incoming_data}

model_class.objects.filter(library=self.old_library).exclude(
**{f"{unique_field}__in": incoming_ids}
).update(library=None)

existing_obj_map = {
getattr(obj, unique_field): obj
for obj in model_class.objects.filter(
**{f"{unique_field}__in": incoming_ids}
)
}

def update_reference_controls(self):
for reference_control in self.reference_controls:
normalized_urn = reference_control["urn"].lower()
ReferenceControl.objects.update_or_create(
urn=normalized_urn,
defaults=reference_control,
create_defaults={
to_create, to_update = [], []
update_fields = set()

for item_data in incoming_data:
normalized_id = item_data[unique_field].lower()
update_fields.update(item_data.keys())

if normalized_id in existing_obj_map:
instance = existing_obj_map[normalized_id]
for key, value in item_data.items():
setattr(instance, key, value)
instance.library = self.old_library
to_update.append(instance)
else:
create_data = {
**self.referential_object_dict,
**self.i18n_object_dict,
**reference_control,
**item_data,
unique_field: normalized_id,
"library": self.old_library,
},
)
}
to_create.append(model_class(**create_data))

created_objects = (
list(model_class.objects.bulk_create(to_create)) if to_create else []
)
if to_update:
update_fields.add("library")
model_class.objects.bulk_update(to_update, list(update_fields))

Comment on lines +437 to +483
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Generic sync helper can silently hijack rows & violate uniqueness.

Key issues spotted:

  1. existing_obj_map is built with

    model_class.objects.filter(**{f"{unique_field}__in": incoming_ids})

    No library filter ⇒ objects that belong to another library are pulled in and later

    instance.library = self.old_library

    re-assigned.
    Result: cross-library objects get “stolen”, breaking referential integrity.

  2. update_fields contains every key plus library plus unique_field (e.g. urn).
    Updating primary identifiers is useless and risks hitting unique constraints; keep them immutable.

  3. Union of arbitrary item_data.keys() goes straight into bulk_update.
    A rogue field in YAML ⇒ runtime FieldError.

Suggested minimal fix:

@@
-existing_obj_map = {
-    getattr(obj, unique_field): obj
-    for obj in model_class.objects.filter(
-        **{f"{unique_field}__in": incoming_ids}
-    )
-}
+existing_obj_map = {
+    getattr(obj, unique_field): obj
+    for obj in model_class.objects.filter(
+        **{
+            f"{unique_field}__in": incoming_ids,
+            "library": self.old_library,   # prevent cross-library hijack
+        }
+    )
+}
@@
-        update_fields.update(item_data.keys())
+        update_fields.update(
+            k for k in item_data.keys() if k != unique_field  # keep id immutable
+        )
@@
-    if to_update:
-        update_fields.add("library")
-        model_class.objects.bulk_update(to_update, list(update_fields))
+    if to_update:
+        model_class.objects.bulk_update(
+            to_update,
+            fields=list(update_fields | {"library"}),
+        )

(You may also want to wrap the whole sync in a transaction to keep old ↔ new state consistent.)

🤖 Prompt for AI Agents
In backend/core/models.py lines 437 to 483, the sync helper lacks a library
filter when building existing_obj_map, causing cross-library objects to be
incorrectly reassigned. Fix this by adding a filter for library=self.old_library
in the queryset used for existing_obj_map. Also, exclude the unique_field from
update_fields to avoid updating primary identifiers, and sanitize update_fields
to include only valid model fields to prevent runtime FieldErrors. Optionally,
wrap the entire sync operation in a transaction to maintain consistency.

return created_objects

def update_reference_controls(self) -> list["ReferenceControl"]:
"""Synchronizes reference controls by delegating to the generic helper."""
return self._synchronize_related_objects(
model_class=ReferenceControl,
incoming_data=self.reference_controls,
unique_field="urn",
)

def update_threats(self) -> list["Threat"]:
"""Synchronizes threats by delegating to the generic helper."""
return self._synchronize_related_objects(
model_class=Threat,
incoming_data=self.threats,
unique_field="urn",
)

def update_frameworks(self):
for new_framework in self.new_frameworks:
Expand Down
210 changes: 210 additions & 0 deletions backend/core/tests/test_library_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import pytest

from core.models import (
LibraryUpdater,
LoadedLibrary,
StoredLibrary,
ReferenceControl,
Threat,
)


@pytest.fixture
def old_library():
"""Fixture to create a real LoadedLibrary instance in the test DB."""
return LoadedLibrary.objects.create(
urn="urn:lib:old:1",
locale="en",
default_locale=True,
provider="TestProvider",
version=1,
)


@pytest.fixture
def new_library_content():
"""Fixture for the content of the new StoredLibrary."""
return {
"provider": "NewProvider",
"dependencies": ["urn:lib:dep:1"],
"reference_controls": [
{"urn": "urn:rc:1", "name": "Control One New"},
{"urn": "urn:rc:2", "name": "Control Two Updated"},
],
"threats": [
{"urn": "urn:threat:100", "name": "New Threat"},
{"urn": "urn:threat:200", "name": "Threat Two Updated"},
],
}


@pytest.fixture
def new_library(new_library_content):
"""Fixture to create a real StoredLibrary instance in the test DB."""
return StoredLibrary.objects.create(
urn="urn:lib:new:1",
content=new_library_content,
dependencies=new_library_content["dependencies"],
provider=new_library_content["provider"],
version=1,
)


@pytest.mark.django_db
class TestLibraryUpdater:
"""Test suite for the object synchronization logic in LibraryUpdater."""

def test_synchronize_full_mixed_mode_for_controls(self, old_library, new_library):
"""Tests create, update, and unlink operations for ReferenceControls in one go."""
ReferenceControl.objects.create(
urn="urn:rc:2",
name="Control Two Original",
library=old_library,
locale="en",
default_locale=True,
provider="TestProvider",
)
ReferenceControl.objects.create(
urn="urn:rc:3",
name="Control Three To Unlink",
library=old_library,
locale="en",
default_locale=True,
provider="TestProvider",
)

updater = LibraryUpdater(old_library, new_library)
created_objects = updater.update_reference_controls()

assert ReferenceControl.objects.count() == 3
assert ReferenceControl.objects.get(urn="urn:rc:1").name == "Control One New"
assert (
ReferenceControl.objects.get(urn="urn:rc:2").name == "Control Two Updated"
)
assert ReferenceControl.objects.get(urn="urn:rc:3").library is None
assert len(created_objects) == 1 and created_objects[0].urn == "urn:rc:1"

def test_synchronize_threats(self, old_library, new_library):
"""Tests create, update, and unlink operations for Threats in one go."""
# Arrange
Threat.objects.create(
urn="urn:threat:200",
name="Threat Two Original",
library=old_library,
locale="en",
default_locale=True,
provider="TestProvider",
)
Threat.objects.create(
urn="urn:threat:300",
name="Threat Three To Unlink",
library=old_library,
locale="en",
default_locale=True,
provider="TestProvider",
)

updater = LibraryUpdater(old_library, new_library)
created_objects = updater.update_threats()

assert Threat.objects.count() == 3
assert Threat.objects.get(urn="urn:threat:100").name == "New Threat"
assert Threat.objects.get(urn="urn:threat:200").name == "Threat Two Updated"
assert Threat.objects.get(urn="urn:threat:300").library is None
assert len(created_objects) == 1 and created_objects[0].urn == "urn:threat:100"

def test_synchronize_only_creates_new_objects(self, old_library, new_library):
"""Tests that all incoming objects are created when the DB is empty."""
assert ReferenceControl.objects.count() == 0

updater = LibraryUpdater(old_library, new_library)
created_objects = updater.update_reference_controls()

assert ReferenceControl.objects.count() == 2
assert len(created_objects) == 2
assert {obj.urn for obj in created_objects} == {"urn:rc:1", "urn:rc:2"}

def test_synchronize_only_updates_existing_objects(self, old_library, new_library):
"""Tests that existing objects are updated and no new ones are created."""
ReferenceControl.objects.create(
urn="urn:rc:1",
name="Original Name 1",
library=old_library,
locale="en",
default_locale=True,
provider="TestProvider",
)
ReferenceControl.objects.create(
urn="urn:rc:2",
name="Original Name 2",
library=old_library,
locale="en",
default_locale=True,
provider="TestProvider",
)
assert ReferenceControl.objects.count() == 2

updater = LibraryUpdater(old_library, new_library)
created_objects = updater.update_reference_controls()

assert ReferenceControl.objects.count() == 2
assert ReferenceControl.objects.get(urn="urn:rc:1").name == "Control One New"
assert (
ReferenceControl.objects.get(urn="urn:rc:2").name == "Control Two Updated"
)
assert len(created_objects) == 0

def test_synchronize_only_unlinks_objects(self, old_library, new_library):
"""Tests that all existing objects are unlinked when incoming data is empty."""
new_library.content["reference_controls"] = []
new_library.save()
ReferenceControl.objects.create(
urn="urn:rc:to_unlink",
name="Old Control",
library=old_library,
locale="en",
default_locale=True,
provider="TestProvider",
)
assert ReferenceControl.objects.count() == 1

updater = LibraryUpdater(old_library, new_library)
created_objects = updater.update_reference_controls()

assert ReferenceControl.objects.count() == 1
assert ReferenceControl.objects.get(urn="urn:rc:to_unlink").library is None
assert created_objects == []

def test_synchronize_no_changes_needed(self, old_library, new_library):
"""Tests that nothing changes when DB state matches incoming data."""
rc1_data = new_library.content["reference_controls"][0]
rc2_data = new_library.content["reference_controls"][1]
ReferenceControl.objects.create(
urn=rc1_data["urn"],
name=rc1_data["name"],
library=old_library,
locale="en",
default_locale=True,
provider="NewProvider",
is_published=True,
)
ReferenceControl.objects.create(
urn=rc2_data["urn"],
name=rc2_data["name"],
library=old_library,
locale="en",
default_locale=True,
provider="NewProvider",
is_published=True,
)
assert ReferenceControl.objects.count() == 2

updater = LibraryUpdater(old_library, new_library)
created_objects = updater.update_reference_controls()

assert ReferenceControl.objects.count() == 2
assert len(created_objects) == 0
# Verify data is unchanged
rc1_db = ReferenceControl.objects.get(urn=rc1_data["urn"])
assert rc1_db.name == rc1_data["name"]
assert rc1_db.library == old_library
Loading