Jafacak.es

Converting Django models and Wagtail pages


Over the lifetime of a Wagtail project, many decisions get made, needs change, and so do the team members. Throughout all of this, there may come a time when you look at some content and think "This shouldn't be a BlogPost, this should be a NewsPage." We will look at converting a single instance of a page from one Page type to another. Whilst this post is not going to cover migrating a whole Page type and all of its data, the code in this post will be reusable for that purpose.

Understanding the problem

Since Wagtail pages are at their core just Django models, a lot of the content in this post is applicable if you are trying to convert between Django models.

Risks to overcome

Possible approaches

There are a few approaches that I considered here and some development preferences that I tried to uphold where possible.

I will be going with the last approach for this post!

The model tree

Our first hurdle is working out if the 2 models that we want to convert between are "compatible". By compatible I mean that the 2 models can have data copied between them and we will have enough data to fully populate the final model.

The following diagram shows the Page types I will use to explain the different scenarios:

Models

In order from simple to complex, we have the following scenarios:

The reason that I have ordered them this way is as follows:

Converting to a parental model

This is the simplest scenario as we can be quite certain that the page instance we are converting contains ALL of the data needed. The only time this will change is if the model that we are converting from has altered one of the inherited fields.

Converting to a parent model will result in an instance that is complete, but it will also result in loss of any data that existed on fields that were defined on the original model.

Converting to a sibling model

The complexity in this scenario is increased over the parental model as siblings share fields from their parent, but then might have additional fields that we won't have data for.

Converting to a sibling model will result in an instance that is partially complete. It will have all the data from the shared parent model, but will have no data in the additional fields defined on the new model.

Converting to a distant model

The complexity in this scenario is increased again as we need to find the shared model and then move data into the specific tables for each of the differing content types.

Converting to a distant model will also result in an instance that is partially complete. Similar to the sibling scenario, except all of the fields defined after the shared ancestor will be empty.

In summary

Based on the above, we can specify that 2 models are compatible to convert between if they share an ancestor that is a Django Model. We might encounter some issues with required fields that we don't have data for. This scenario will be addressed near the end of this post.

Writing code that works for the distant model scenario should result in code that works for all three of the scenarios described above.
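To make the three scenarios concrete, here is a small sketch in plain Python (no Django involved). The `FIELDS` and `PARENT` maps are hypothetical stand-ins that mirror the Page types used in this post, and `conversion_report` is an illustrative helper, not part of the final code. It reports which fields would be lost and which would be left empty for a given conversion:

```python
# Hypothetical maps mirroring the Page types used in this post.
FIELDS = {
    "Page": {"title"},
    "HomePage": {"status"},
    "BasePage": {"body"},
    "BlogPage": {"enable_comments"},
    "NewsPage": {"category"},
}
PARENT = {
    "Page": None,
    "HomePage": "Page",
    "BasePage": "Page",
    "BlogPage": "BasePage",
    "NewsPage": "BasePage",
}


def lineage(name):
    """Return the chain from a model up to the root, e.g. NewsPage, BasePage, Page."""
    chain = []
    while name is not None:
        chain.append(name)
        name = PARENT[name]
    return chain


def conversion_report(from_model, to_model):
    from_chain = lineage(from_model)
    to_chain = lineage(to_model)
    # First model in the source lineage that the target also inherits from
    shared = next(m for m in from_chain if m in to_chain)
    dropped = from_chain[: from_chain.index(shared)]
    created = to_chain[: to_chain.index(shared)]
    return {
        "shared": shared,
        # Fields defined below the shared ancestor on the source side are lost
        "lost": set().union(*(FIELDS[m] for m in dropped)),
        # Fields defined below the shared ancestor on the target side start empty
        "empty": set().union(*(FIELDS[m] for m in created)),
    }


print(conversion_report("NewsPage", "HomePage"))
```

The "empty" set is the part to watch: any required field in it needs a value from somewhere before the conversion can succeed.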

Now let's get to the useful part! The code:

The code!

First of all, we will set up some models to test with. The models defined below should match the diagram from above:

# models.py
from django.db import models

class Page(models.Model):
    """
    Placeholder Page model for now.
    """
    # Stand-in for the `title` field that Wagtail's Page provides
    title = models.CharField(max_length=255)

class HomePage(Page):
    status = models.CharField(max_length=255, blank=True, null=True)


class BasePage(Page):
    body = models.TextField()


class BlogPage(BasePage):
    enable_comments = models.BooleanField(default=True)


class NewsPage(BasePage):
    category = models.CharField(max_length=255)

We are using a plain Django Model instead of a Wagtail Page to simplify things for now. We will switch it out for a Wagtail Page later, once we have the code working for Django Models.

Now we can define our first function. We will start with the signature of the function that converts an instance to a particular model:

from typing import Type

def convert(
    model_instance: models.Model,
    to_model: Type[models.Model],
):
    ...

So for example, we will call the function in the following ways for each of the scenarios we laid out previously:

news_page = NewsPage.objects.create(
    title="News Page",
    body="News Body",
    category="Some Category",
)

# Direct parent
convert(news_page, BasePage)

# Direct sibling
convert(news_page, BlogPage)

# Distant relative
convert(news_page, HomePage)

We can comment the function with the planned steps to keep track:

def convert(
    model_instance: models.Model,
    to_model: Type[models.Model],
):
    # Find the first common ancestor between the model_instance class and the to_model
    # Clear out all of the tables between the model_instance class and the first common ancestor
    # Add the data back in for each of the classes from the common ancestor back down to the to_model
    ...

Finding the first common ancestor is a handy piece of code that can be separate from the rest of this function, so I will write it as its own function:

import inspect

def find_first_common_ancestor(
    from_model: Type[models.Model],
    to_model: Type[models.Model],
) -> Type[models.Model]:
    if from_model == to_model:
        return from_model

    from_mro = inspect.getmro(from_model)
    to_mro = inspect.getmro(to_model)

    for from_cls in from_mro:
        # Non-model classes in the MRO (for example `object`) have no `_meta`
        meta = getattr(from_cls, "_meta", None)
        if from_cls in to_mro and meta is not None and not meta.abstract:
            return from_cls

    raise ValueError("No common ancestor found")

The above code will get the MRO for the from_model and the to_model and then find the first class in the from_model MRO that is also in the to_model MRO, skipping abstract models and anything that is not a Django model.

If the two models share a concrete ancestor, the ValueError at the end will never be raised. It is there so that someone who passes in two unrelated models gets a helpful message.
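If the MRO is unfamiliar, the following sketch shows the same search with plain Python classes instead of Django models. The class names mirror our models but are ordinary classes, and `first_shared_class` is a simplified, hypothetical variant that has no concept of abstract models:

```python
import inspect


class Page: ...
class HomePage(Page): ...
class BasePage(Page): ...
class BlogPage(BasePage): ...
class NewsPage(BasePage): ...


def first_shared_class(from_cls, to_cls):
    """Walk the source MRO and return the first class the target also inherits from."""
    to_mro = inspect.getmro(to_cls)
    for candidate in inspect.getmro(from_cls):
        if candidate in to_mro:
            return candidate
    raise ValueError("No common ancestor found")


# The MRO lists the class itself first, then its ancestors in lookup order
print([c.__name__ for c in inspect.getmro(NewsPage)])
# ['NewsPage', 'BasePage', 'Page', 'object']

print(first_shared_class(NewsPage, BasePage).__name__)  # parent
print(first_shared_class(NewsPage, BlogPage).__name__)  # sibling
print(first_shared_class(NewsPage, HomePage).__name__)  # distant relative
```

With plain classes every pair shares `object`, so the ValueError can never fire here. The Django version differs because it only accepts concrete models as the answer.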

We can update our convert method:

def convert(
    model_instance: models.Model,
    to_model: Type[models.Model],
):
    common_ancestor = find_first_common_ancestor(
        from_model=model_instance.__class__,
        to_model=to_model,
    )
    # Clear out all of the tables between the model_instance class and the first common ancestor
    # Add the data back in for each of the classes from the common ancestor back down to the to_model

To clear out all of the tables between the model_instance class and the common_ancestor we can just loop over each of the classes in the model_instance class MRO until we reach the ancestor.

for cls in inspect.getmro(model_instance.__class__):
    if cls == common_ancestor:
        break
    # TODO: Clear out the table

To clear out the table we can use the `delete` method from the Model class and use the keep_parents arg. This will delete only the table row for the current class that we are acting on.

That means that we can load the object from the database for the current class we are on in the MRO, and then call `.delete(keep_parents=True)`.

for cls in inspect.getmro(model_instance.__class__):
    if cls == common_ancestor:
        break
    cls.objects.get(pk=model_instance.pk).delete(keep_parents=True)
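It can help to picture what this does at the table level. The sketch below uses the standard library's sqlite3 module with hand-written tables that roughly imitate multi-table inheritance. The table and column names are simplified assumptions (Django would prefix the tables with the app label), and the DELETE is only an approximation of what `keep_parents=True` achieves for a single class:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE page (id INTEGER PRIMARY KEY, title TEXT);
    CREATE TABLE basepage (
        page_ptr_id INTEGER PRIMARY KEY REFERENCES page(id),
        body TEXT
    );
    CREATE TABLE newspage (
        basepage_ptr_id INTEGER PRIMARY KEY REFERENCES basepage(page_ptr_id),
        category TEXT
    );
    INSERT INTO page VALUES (1, 'News Page');
    INSERT INTO basepage VALUES (1, 'News Body');
    INSERT INTO newspage VALUES (1, 'Some Category');
    """
)

# Remove only the NewsPage-specific row; the parent rows stay where they are
conn.execute("DELETE FROM newspage WHERE basepage_ptr_id = ?", (1,))


def count(table):
    return conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]


counts = {table: count(table) for table in ("page", "basepage", "newspage")}
print(counts)
```

After the delete, the page and basepage rows are still there, which is exactly the state we want before adding rows for the new type.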

Now we can update the convert method:

def convert(
    model_instance: models.Model,
    to_model: Type[models.Model],
):
    common_ancestor = find_first_common_ancestor(
        from_model=model_instance.__class__,
        to_model=to_model,
    )

    # Clear all tables up to the common ancestor
    for cls in inspect.getmro(model_instance.__class__):
        if cls == common_ancestor:
            break
        cls.objects.get(pk=model_instance.pk).delete(keep_parents=True)

    # Add the data back in for each of the classes from the common ancestor back down to the to_model

Now we no longer have any data in model_instance class table, or any of the tables between the model_instance class table and the common_ancestor table.

Now if we create the data down from the common_ancestor to the to_model, we will have completed the conversion! This is easier said than done though as there are a few obstacles to navigate along the way.

We will start by getting a list of all of the classes between the common_ancestor and the to_model.

# Example vars:
# to_model = NewsPage
# common_ancestor = Page

# Get the to_model MRO
to_model_mro = inspect.getmro(to_model)
# Example:
# (<class 'core.models.NewsPage'>, <class 'core.models.BasePage'>, <class 'wagtail.models.Page'>, ... )

# Remove everything in the MRO after, and including, the common ancestor
to_model_mro = to_model_mro[: to_model_mro.index(common_ancestor)]
# Example:
# (<class 'core.models.NewsPage'>, <class 'core.models.BasePage'>)

# Reverse the MRO so we create the tables in the correct order
to_model_mro = to_model_mro[::-1]
# Example:
# (<class 'core.models.BasePage'>, <class 'core.models.NewsPage'>)

The reason that we reverse the order is because we need to know the parent table to know the name of the `ptr_id` field.

Given the example, we want to do the following:

The `ptr_id` field is named using the lowercase of the parent model class name. This can be accessed using `Model._meta.model_name`. Logically this means that we can start with the model_name from the common ancestor and then, looping over each of the classes in the MRO, use the previous class as the new parent class and repeat:

previous_mro_class = common_ancestor
for cls in to_model_mro:
    # Skip abstract classes
    if cls._meta.abstract:
        continue

    ptr_field_name = f"{previous_mro_class._meta.model_name}_ptr_id"
    
    # TODO: perform operation to create the table for the current `cls`
    
    previous_mro_class = cls

    if cls == to_model:
        break
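To check the naming logic without a database, here is a sketch using plain classes. `cls.__name__.lower()` stands in for `_meta.model_name`, and `creation_plan` is a hypothetical helper that only returns the order and the pointer column names rather than creating anything:

```python
import inspect


class Page: ...
class BasePage(Page): ...
class NewsPage(BasePage): ...


def creation_plan(to_model, common_ancestor):
    """Return (class name, pointer column) pairs in the order rows must be created."""
    mro = inspect.getmro(to_model)
    # Keep only the classes below the common ancestor, parents first
    chain = mro[: mro.index(common_ancestor)][::-1]

    plan = []
    parent = common_ancestor
    for cls in chain:
        plan.append((cls.__name__, f"{parent.__name__.lower()}_ptr_id"))
        parent = cls
    return plan


print(creation_plan(NewsPage, Page))
# [('BasePage', 'page_ptr_id'), ('NewsPage', 'basepage_ptr_id')]
```

If the target is the common ancestor itself, the plan is empty, which matches the parent scenario where nothing needs to be created.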

Now we want to create the table data for the class. To do this we can create an instance of the class like so:

cls(page_ptr_id=model_instance.pk)

But since we know the `ptr` field name will change each time, we will want to build a kwargs dict that we can unpack and pass into the class initialisation, like so:

new_cls_kwargs = {
    ptr_field_name: model_instance.pk
}
new_model_instance = cls(**new_cls_kwargs)
new_model_instance.save()

While the above code will work, it results in us having no data on the model as it calls `save` on a class that doesn't hold values for the fields on the previous models. To fix this we can update the new_model_instance to contain the field data we know about:

new_cls_kwargs = {
    ptr_field_name: model_instance.pk
}
new_model_instance = cls(**new_cls_kwargs)
new_model_instance.__dict__.update(model_instance.__dict__)
new_model_instance.save()

Great! We have something that works 😄

Time to update the convert method with all of this new code:

def convert(
    model_instance: models.Model,
    to_model: Type[models.Model],
):
    common_ancestor = find_first_common_ancestor(
        from_model=model_instance.__class__,
        to_model=to_model,
    )

    # Clear all tables up to the common ancestor
    for cls in inspect.getmro(model_instance.__class__):
        if cls == common_ancestor:
            break
        if cls._meta.abstract:
            continue
        cls.objects.get(pk=model_instance.pk).delete(keep_parents=True)

    # Get the to_model MRO
    to_model_mro = inspect.getmro(to_model)
    # Remove everything in the MRO after, and including, the common ancestor
    to_model_mro = to_model_mro[:to_model_mro.index(common_ancestor)]
    # Reverse the MRO so we create the tables in the correct order
    to_model_mro = to_model_mro[::-1]

    previous_mro_class = common_ancestor
    for cls in to_model_mro:
        # Skip abstract classes
        if cls._meta.abstract:
            continue

        ptr_field_name = f"{previous_mro_class._meta.model_name}_ptr_id"
    
        new_cls_kwargs = {
            ptr_field_name: model_instance.pk
        }
        new_model_instance = cls(**new_cls_kwargs)
        new_model_instance.__dict__.update(model_instance.__dict__)
        new_model_instance.save()
    
        previous_mro_class = cls

        if cls == to_model:
            break

What about models that have required fields?

A great question, we will define a model that meets this scenario first so that we can address this issue:

class RequiredFieldPage(BasePage):
    important_data = models.BooleanField()

If we try to run the convert method on a NewsPage to RequiredFieldPage we hit some database issues because we aren't providing the important_data column with data, and it's a "NOT NULL" field.

To overcome this, I decided that the best approach would be the ability to define some default kwargs for each class. This will be a dict of dicts: the keys of the outer dict are the classes that we want to pass kwargs for, and each inner dict maps field names to field values. For example:

{
    RequiredFieldPage: {
        "important_data": True,
    },
}

This format is great as it means that we can provide default values for required fields, but also for other fields on a per model class basis. Example:

{
    BasePage: {
        "body": "This is my new body content",
    },
    RequiredFieldPage: {
        "important_data": True,
    },
}
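As a rough sketch of how such a dict could be applied, the helper below merges the defaults for one class over the data we already know about. The classes are plain Python stand-ins and `kwargs_for` is a hypothetical name used only for illustration:

```python
class BasePage: ...
class RequiredFieldPage(BasePage): ...


def kwargs_for(cls, known_data, default_data=None):
    """Merge the per-class defaults over the data we already have."""
    kwargs = dict(known_data)  # copy so the caller's dict is never mutated
    # Defaults are looked up by the exact class, not by inheritance
    kwargs.update((default_data or {}).get(cls, {}))
    return kwargs


default_data = {
    BasePage: {"body": "This is my new body content"},
    RequiredFieldPage: {"important_data": True},
}
known = {"body": "News Body"}

base_kwargs = kwargs_for(BasePage, known, default_data)
required_kwargs = kwargs_for(RequiredFieldPage, known, default_data)
print(base_kwargs)
print(required_kwargs)
```

Because the defaults are applied last, they override copied values as well as filling in missing ones, which is what lets us replace the body in the example above.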

We can pass this into the convert method if we update the function signature like so:

def convert(
    model_instance: models.Model,
    to_model: Type[models.Model],
    default_data: dict | None = None,
):

We could type the default_data arg, but I will do that later.

Now we just have to set a default empty dict if the default_data is None, and insert the correct dicts into the class kwargs when we are adding the data back in.

def convert(
    model_instance: models.Model,
    to_model: Type[models.Model],
    default_data: dict | None = None,
):
    if default_data is None:
        default_data = {}

    common_ancestor = find_first_common_ancestor(model_instance.__class__, to_model)

    # Clear all tables up to the common ancestor
    for cls in inspect.getmro(model_instance.__class__):
        if cls == common_ancestor:
            break
        if cls._meta.abstract:
            continue

        cls.objects.get(pk=model_instance.pk).delete(keep_parents=True)

    # Create all of the tables down to the target model
    to_model_mro = inspect.getmro(to_model)
    # Remove everything in the MRO after, and including, the common ancestor
    to_model_mro = to_model_mro[: to_model_mro.index(common_ancestor)]
    # Reverse the MRO so we create the tables in the correct order
    to_model_mro = to_model_mro[::-1]

    # Update the `content_type` value (only present on models such as Wagtail's Page)
    if hasattr(common_ancestor, "content_type"):
        common_ancestor_instance = common_ancestor.objects.get(pk=model_instance.pk)
        common_ancestor_instance.content_type = ContentType.objects.get_for_model(to_model)
        common_ancestor_instance.save(update_fields=["content_type_id"])

    previous_mro_class = common_ancestor
    for cls in to_model_mro:
        # Skip abstract classes
        if cls._meta.abstract:
            continue

        ptr_field_name = f"{previous_mro_class._meta.model_name}_ptr_id"
        new_cls_fields = [f.name for f in cls._meta.get_fields()]
        new_cls_kwargs = model_instance.__dict__.copy()
        new_cls_kwargs[ptr_field_name] = model_instance.id
        if cls in default_data:
            new_cls_kwargs.update(default_data[cls])

        # Get a list of fields that are not in the new class
        fields_to_remove = set(new_cls_kwargs.keys()) - set(new_cls_fields)
        for field in fields_to_remove:
            del new_cls_kwargs[field]

        # The kwargs already carry the known field data, so no extra copy is needed
        model_instance = cls(**new_cls_kwargs)
        model_instance.save()

        previous_mro_class = cls

        if cls == to_model:
            break

In the above, you will also see there are quite a few changes:

What happens if the code fails part way through the operation?

Another great question!

If we leave the function as it currently stands, we are exposed to leaving the database in a horrendous state. To fix this we can use a transaction. This can be done by decorating the function with @transaction.atomic.
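The effect is easy to demonstrate outside Django. The sketch below uses sqlite3 from the standard library, where `with conn:` plays a similar role to `transaction.atomic`: it commits if the block succeeds and rolls back if an exception escapes. The tables and the `convert_row` helper are simplified assumptions for illustration only:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE newspage (basepage_ptr_id INTEGER PRIMARY KEY, category TEXT);
    CREATE TABLE requiredfieldpage (
        basepage_ptr_id INTEGER PRIMARY KEY,
        important_data BOOLEAN NOT NULL
    );
    INSERT INTO newspage VALUES (1, 'Some Category');
    """
)


def convert_row(conn, pk, important_data):
    # Both statements succeed together or are rolled back together
    with conn:
        conn.execute("DELETE FROM newspage WHERE basepage_ptr_id = ?", (pk,))
        conn.execute(
            "INSERT INTO requiredfieldpage VALUES (?, ?)", (pk, important_data)
        )


try:
    convert_row(conn, 1, None)  # None violates the NOT NULL constraint
except sqlite3.IntegrityError:
    pass

# The failed insert also undid the delete, so the original row is still there
remaining = conn.execute("SELECT COUNT(*) FROM newspage").fetchone()[0]
print(remaining)
```

Without the transaction, the delete would have gone through and the page would be left with no row for either type.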

The code for converting between Django models!

Below is the complete code for converting between Django models:

import inspect
from typing import Type

from django.contrib.contenttypes.models import ContentType
from django.db import models, transaction


def find_first_common_ancestor(
    from_model: Type[models.Model],
    to_model: Type[models.Model],
) -> Type[models.Model]:
    if from_model == to_model:
        return from_model

    from_mro = inspect.getmro(from_model)
    to_mro = inspect.getmro(to_model)

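    for from_cls in from_mro:
        # Non-model classes in the MRO (for example `object`) have no `_meta`
        meta = getattr(from_cls, "_meta", None)
        if from_cls in to_mro and meta is not None and not meta.abstract:
            return from_cls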

    raise ValueError("No common ancestor found")


@transaction.atomic
def convert(
    model_instance: models.Model,
    to_model: Type[models.Model],
    default_data: dict | None = None,
):
    if default_data is None:
        default_data = {}

    common_ancestor = find_first_common_ancestor(model_instance.__class__, to_model)

    # Clear all tables up to the common ancestor
    for cls in inspect.getmro(model_instance.__class__):
        if cls == common_ancestor:
            break
        # Abstract classes have no table of their own to clear
        if cls._meta.abstract:
            continue

        cls.objects.get(pk=model_instance.pk).delete(keep_parents=True)

    # Create all of the tables down to the target model
    to_model_mro = inspect.getmro(to_model)
    # Remove everything in the MRO after, and including, the common ancestor
    to_model_mro = to_model_mro[: to_model_mro.index(common_ancestor)]
    # Reverse the MRO so we create the tables in the correct order
    to_model_mro = to_model_mro[::-1]

    # Update the `content_type` value (only present on models such as Wagtail's Page)
    if hasattr(common_ancestor, "content_type"):
        common_ancestor_instance = common_ancestor.objects.get(pk=model_instance.pk)
        common_ancestor_instance.content_type = ContentType.objects.get_for_model(to_model)
        common_ancestor_instance.save(update_fields=["content_type_id"])

    previous_mro_class = common_ancestor
    for cls in to_model_mro:
        # Skip abstract classes
        if cls._meta.abstract:
            continue

        new_cls_fields = [f.name for f in cls._meta.get_fields()]
        new_cls_kwargs = model_instance.__dict__.copy()
        new_cls_kwargs[f"{previous_mro_class._meta.model_name}_ptr_id"] = (
            model_instance.id
        )
        if cls in default_data:
            new_cls_kwargs.update(default_data[cls])

        # Get a list of fields that are not in the new class
        fields_to_remove = set(new_cls_kwargs.keys()) - set(new_cls_fields)
        for field in fields_to_remove:
            del new_cls_kwargs[field]

        # The kwargs already carry the known field data, so no extra copy is needed
        model_instance = cls(**new_cls_kwargs)
        model_instance.save()

        previous_mro_class = cls

        if cls == to_model:
            break

What about Wagtail pages?

If we switch out our testing base class from the custom Page class to the Wagtail Page class, we will see that we encounter some issues.
The first issue we hit is to do with how Wagtail manages page deletion:

TypeError
TreeQuerySet.delete() got an unexpected keyword argument 'keep_parents'

Looking at the Wagtail code, it doesn't seem to respect the `keep_parents` behaviour. I'm not sure if this is intentional or not.

I will carry on, with the aim of proposing the changes I make back to Wagtail.

There are 2 changes below:

from wagtail.actions.delete_page import DeletePageAction


class FixedDeletePageAction(DeletePageAction):
    def _delete_page(self, page, *args, **kwargs):
        from wagtail.models import Page

        keep_parents = kwargs.pop("keep_parents", False)

        # Ensure that deletion always happens on an instance of Page, not a specific subclass. This
        # works around a bug in treebeard <= 3.0 where calling SpecificPage.delete() fails to delete
        # child pages that are not instances of SpecificPage
        if type(page) is Page:
            for child in page.get_descendants().specific().iterator():
                self.log_deletion(child)
            self.log_deletion(page.specific)

            # this is a Page instance, so carry on as we were
            return super(Page, page).delete(*args, **kwargs)
        elif not keep_parents:
            # retrieve an actual Page instance and delete that instead of page
            return DeletePageAction(
                Page.objects.get(id=page.id), user=self.user
            ).execute(*args, **kwargs)

# Add the following method to models that directly inherit from Wagtail Page:
def delete(self, *args, **kwargs):
    user = kwargs.pop("user", None)
    return FixedDeletePageAction(self, user=user).execute(*args, **kwargs)

With that change, we now hit a new issue... Creating the tables from the common_ancestor down to the to_model stops working because it thinks that we are creating a new page with an ID that is already in use, a slug that isn't unique, etc.

I think we've got to reach for some SQL now. I've been putting it off for a while as I really didn't want to do this, but I can't think of another way.

We're going to be changing the code from the line `previous_mro_class = common_ancestor` down. Add the import `from django.db import connections` and set `db_connection = connections["default"]`:

previous_mro_class = common_ancestor
for cls in to_model_mro:
    # Skip abstract classes
    if cls._meta.abstract:
        continue

    new_cls_fields = [f.name for f in cls._meta.get_fields()]
    new_cls_kwargs = model_instance.__dict__.copy()

    # Get a list of fields that are not in the new class
    fields_to_remove = set(new_cls_kwargs.keys()) - set(new_cls_fields)
    # Clear the base Model fields.
    if issubclass(cls, common_ancestor):
        fields_to_remove.update(
            set(
                [
                    f.name
                    for f in common_ancestor._meta.get_fields()
                    if f.name in new_cls_kwargs
                ]
            )
        )

    for field in fields_to_remove:
        del new_cls_kwargs[field]

    new_cls_kwargs[f"{previous_mro_class._meta.model_name}_ptr_id"] = (
        model_instance.id
    )

    if cls in default_data:
        new_cls_kwargs.update(default_data[cls])

    # Add the row to the new table.
    with db_connection.cursor() as cursor:
        table_name = cls._meta.db_table
        fields = ",".join(new_cls_kwargs.keys())
        values = ",".join([f"'{str(v)}'" for v in new_cls_kwargs.values()])
        cursor.execute(
            f"""
            INSERT INTO {table_name} ({fields})
            VALUES ({values});
            """
        )

    previous_mro_class = cls

    if cls == to_model:
        break

When testing this in one direction it works fine, but converting back it seems we hit an error:
duplicate key value violates unique constraint
DETAIL: Key (contentpage_ptr_id)=(50) already exists.

Looking at the data in the tables, it seems the Wagtail override didn't work. So we will revert those changes and replace it with SQL delete statements.

We will need a new helper function that gets the next model in the MRO that isn't abstract:

def get_next_model_from_mro(
    model: Type[models.Model],
) -> Type[models.Model]:
    for cls in inspect.getmro(model):
        if cls == model:
            continue
        if not cls._meta.abstract:
            return cls
    raise ValueError("No concrete class found")

And now we can update the "Clear all tables up to the common ancestor" section with the following:

# Clear all tables up to the common ancestor
for cls in inspect.getmro(model_instance.__class__):
    if cls == common_ancestor:
        break
    if cls._meta.abstract:
        continue
    next_mro_class = get_next_model_from_mro(cls)
    ptr_field_name = f"{next_mro_class._meta.model_name}_ptr_id"
    # Clear the row from the table for the current class
    with db_connection.cursor() as cursor:
        # Pass the pk as a query parameter instead of formatting it into the SQL
        cursor.execute(
            f"DELETE FROM {cls._meta.db_table} WHERE {ptr_field_name} = %s",
            [model_instance.pk],
        )

And that should be it!

The FINAL code:

import inspect
from typing import Type

from django.contrib.contenttypes.models import ContentType
from django.db import connections, models, transaction


def find_first_common_ancestor(
    from_model: Type[models.Model],
    to_model: Type[models.Model],
) -> Type[models.Model]:
    if from_model == to_model:
        return from_model

    from_mro = inspect.getmro(from_model)
    to_mro = inspect.getmro(to_model)

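    for from_cls in from_mro:
        # Non-model classes in the MRO (for example `object`) have no `_meta`
        meta = getattr(from_cls, "_meta", None)
        if from_cls in to_mro and meta is not None and not meta.abstract:
            return from_cls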

    raise ValueError("No common ancestor found")


def get_next_model_from_mro(
    model: Type[models.Model],
) -> Type[models.Model]:
    for cls in inspect.getmro(model):
        if cls == model:
            continue
        if not cls._meta.abstract:
            return cls
    raise ValueError("No concrete class found")


def build_values(values: list) -> list[str]:
    output_values = []

    for v in values:
        if isinstance(v, bool):
            output_values.append(str(v))
        else:
            output_values.append(f"'{str(v)}'")

    return output_values


@transaction.atomic
def convert(
    model_instance: models.Model,
    to_model: Type[models.Model],
    default_data: dict | None = None,
):
    db_connection = connections["default"]

    if default_data is None:
        default_data = {}

    common_ancestor = find_first_common_ancestor(model_instance.__class__, to_model)

    # Clear all tables up to the common ancestor
    for cls in inspect.getmro(model_instance.__class__):
        if cls == common_ancestor:
            break
        if cls._meta.abstract:
            continue
        next_mro_class = get_next_model_from_mro(cls)
        ptr_field_name = f"{next_mro_class._meta.model_name}_ptr_id"
        # Clear the row from the table for the current class
        with db_connection.cursor() as cursor:
            # Pass the pk as a query parameter instead of formatting it into the SQL
            cursor.execute(
                f"DELETE FROM {cls._meta.db_table} WHERE {ptr_field_name} = %s",
                [model_instance.pk],
            )

    # Create all of the tables down to the target model
    to_model_mro = inspect.getmro(to_model)
    # Remove everything in the MRO after, and including, the common ancestor
    to_model_mro = to_model_mro[: to_model_mro.index(common_ancestor)]
    # Reverse the MRO so we create the tables in the correct order
    to_model_mro = to_model_mro[::-1]

    # Update the `content_type` value
    if hasattr(common_ancestor, "content_type"):
        common_ancestor_instance = common_ancestor.objects.get(pk=model_instance.pk)
        common_ancestor_instance.content_type = ContentType.objects.get_for_model(
            to_model
        )
        common_ancestor_instance.save(update_fields=["content_type_id"])

    previous_mro_class = common_ancestor
    for cls in to_model_mro:
        # Skip abstract classes
        if cls._meta.abstract:
            continue

        new_cls_fields = [f.name for f in cls._meta.get_fields()]
        new_cls_kwargs = model_instance.__dict__.copy()

        # Get a list of fields that are not in the new class
        fields_to_remove = set(new_cls_kwargs.keys()) - set(new_cls_fields)
        # Clear the base Model fields.
        if issubclass(cls, common_ancestor):
            fields_to_remove.update(
                set(
                    [
                        f.name
                        for f in common_ancestor._meta.get_fields()
                        if f.name in new_cls_kwargs
                    ]
                )
            )

        for field in fields_to_remove:
            del new_cls_kwargs[field]

        new_cls_kwargs[f"{previous_mro_class._meta.model_name}_ptr_id"] = (
            model_instance.id
        )

        if cls in default_data:
            new_cls_kwargs.update(default_data[cls])

        # Add the row to the new table.
        with db_connection.cursor() as cursor:
            table_name = cls._meta.db_table
            fields = ",".join(new_cls_kwargs.keys())
            values = ",".join([v for v in build_values(new_cls_kwargs.values())])
            print(f"Adding row to {table_name}", fields, values)
            cursor.execute(
                f"""
                INSERT INTO {table_name} ({fields})
                VALUES ({values});
                """
            )

        previous_mro_class = cls

        if cls == to_model:
            break
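One caveat with building the INSERT by quoting values into the SQL string: a value containing an apostrophe breaks the statement, and None ends up as the text 'None'. A more robust option is to pass the values as query parameters. The sketch below shows the idea with sqlite3 from the standard library, which uses `?` placeholders; Django's cursor takes `%s` placeholders with the values as a second argument to `cursor.execute`. The `insert_row` helper and the table are hypothetical:

```python
import sqlite3


def insert_row(conn, table_name, row):
    """Insert a dict of column -> value, letting the driver handle quoting."""
    columns = ", ".join(row.keys())
    placeholders = ", ".join("?" for _ in row)
    # Table and column names come from model metadata, never from user input
    sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders})"
    conn.execute(sql, list(row.values()))


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE homepage (page_ptr_id INTEGER PRIMARY KEY, status TEXT)")

insert_row(conn, "homepage", {"page_ptr_id": 1, "status": "It's an interesting status!"})
insert_row(conn, "homepage", {"page_ptr_id": 2, "status": None})

rows = conn.execute(
    "SELECT page_ptr_id, status FROM homepage ORDER BY page_ptr_id"
).fetchall()
print(rows)
```

With parameters, the apostrophe is stored as-is and None becomes a real NULL, so a helper like build_values is no longer needed.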

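And to use it, we can refer back to our scenarios at the start with a little tweak. Each call below is an alternative rather than a sequence: run one conversion per page, and reload the page from the target model afterwards (for example BasePage.objects.get(pk=news_page.pk)), because the instance you passed in still belongs to the old class: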

news_page = NewsPage.objects.create(
    title="News Page",
    body="News Body",
    category="Some Category",
)

# Direct parent
convert(news_page, BasePage)

# Direct sibling
convert(news_page, BlogPage, default_data={
    BlogPage: {
        "enable_comments": False,
    },
})

# Distant relative
convert(news_page, HomePage, default_data={
    HomePage: {
        "status": "An interesting status message!"
    },
})