コンテンツへスキップ

データベースのテスト

情報

これらのドキュメントは更新されようとしています。🎉

現在のバージョンはPydantic v1とSQLAlchemy 2.0未満のバージョンを想定しています。

新しいドキュメントにはPydantic v2が含まれ、SQLModel(これもSQLAlchemyに基づいています)がPydantic v2も使用するように更新されたら使用します。

オーバーライドによる依存関係のテストの同じ依存関係オーバーライドを使用して、テスト用のデータベースを変更できます。

テスト用に異なるデータベースをセットアップしたり、テスト後にデータをロールバックしたり、テストデータで事前に埋めたりすることができます。

主な考え方は、前の章で見たものとまったく同じです。

SQLアプリのテストを追加

SQL(リレーショナル)データベースの例を更新して、テストデータベースを使用します。

すべてのアプリコードは同じです。前の章に戻って確認できます。

ここで変更されるのは新しいテストファイルだけです。

通常の依存関係`get_db()`はデータベースセッションを返します。

テストでは、依存関係オーバーライドを使用して、通常使用されるものとは別のカスタムデータベースセッションを返すことができます。

この例では、テスト専用のテンポラリデータベースを作成します。

ファイル構造

`sql_app/tests/test_sql_app.py`に新しいファイルを作成します。

新しいファイル構造は次のようになります。

.
└── sql_app
    ├── __init__.py
    ├── crud.py
    ├── database.py
    ├── main.py
    ├── models.py
    ├── schemas.py
    └── tests
        ├── __init__.py
        └── test_sql_app.py

新しいデータベースセッションを作成

まず、新しいデータベースを使用して新しいデータベースセッションを作成します。

ローカルファイル`sql_app.db`の代わりに、テスト中に保持されるインメモリデータベースを使用します。

しかし、セッションコードの残りはほぼ同じで、コピーするだけです。

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from ..database import Base
from ..main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite://"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


Base.metadata.create_all(bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


def test_create_user():
    response = client.post(
        "/users/",
        json={"email": "deadpool@example.com", "password": "chimichangas4life"},
    )
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert "id" in data
    user_id = data["id"]

    response = client.get(f"/users/{user_id}")
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert data["id"] == user_id

ヒント

そのコードの重複を減らすには、関数を置いて`database.py`と`tests/test_sql_app.py`の両方から使用することができます。

シンプルさを保ち、テストコードに焦点を当てるために、コピーしています。

データベースを作成

新しいファイルで新しいデータベースを使用するようになったので、次のようにデータベースを作成する必要があります。

Base.metadata.create_all(bind=engine)

これは通常`main.py`で呼び出されますが、`main.py`の行はデータベースファイル`sql_app.db`を使用しており、テスト用に`test.db`を作成する必要があります。

そこで、新しいファイルと共にその行を追加します。

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from ..database import Base
from ..main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite://"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


Base.metadata.create_all(bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


def test_create_user():
    response = client.post(
        "/users/",
        json={"email": "deadpool@example.com", "password": "chimichangas4life"},
    )
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert "id" in data
    user_id = data["id"]

    response = client.get(f"/users/{user_id}")
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert data["id"] == user_id

依存関係のオーバーライド

次に、依存関係のオーバーライドを作成し、アプリのオーバーライドに追加します。

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from ..database import Base
from ..main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite://"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


Base.metadata.create_all(bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


def test_create_user():
    response = client.post(
        "/users/",
        json={"email": "deadpool@example.com", "password": "chimichangas4life"},
    )
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert "id" in data
    user_id = data["id"]

    response = client.get(f"/users/{user_id}")
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert data["id"] == user_id

ヒント

override_get_db() のコードは get_db() とほぼ同じですが、override_get_db() では、テストデータベース用に TestingSessionLocal を使用します。

アプリのテスト

その後、通常通りアプリをテストできます。

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from ..database import Base
from ..main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite://"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


Base.metadata.create_all(bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


def test_create_user():
    response = client.post(
        "/users/",
        json={"email": "deadpool@example.com", "password": "chimichangas4life"},
    )
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert "id" in data
    user_id = data["id"]

    response = client.get(f"/users/{user_id}")
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert data["id"] == user_id

そして、テスト中にデータベースに行ったすべての変更は、メインの sql_app.db ではなく、test.db データベースに反映されます。