Add some basic tests to AgentDB. (#2666)
This commit is contained in:
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@@ -107,7 +107,7 @@ jobs:
|
||||
env:
|
||||
ENABLE_OPENAI: "true"
|
||||
OPENAI_API_KEY: "sk-dummy"
|
||||
run: poetry run pytest tests
|
||||
run: poetry run pytest
|
||||
fe-lint-build:
|
||||
name: Frontend Lint and Build
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
23
poetry.lock
generated
23
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "about-time"
|
||||
@@ -250,6 +250,25 @@ files = [
|
||||
[package.dependencies]
|
||||
frozenlist = ">=1.1.0"
|
||||
|
||||
[[package]]
|
||||
name = "aiosqlite"
|
||||
version = "0.21.0"
|
||||
description = "asyncio bridge to the standard sqlite3 module"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0"},
|
||||
{file = "aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
typing_extensions = ">=4.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["attribution (==1.7.1)", "black (==24.3.0)", "build (>=1.2)", "coverage[toml] (==7.6.10)", "flake8 (==7.0.0)", "flake8-bugbear (==24.12.12)", "flit (==3.10.1)", "mypy (==1.14.1)", "ufmt (==2.5.1)", "usort (==1.0.8.post1)"]
|
||||
docs = ["sphinx (==8.1.3)", "sphinx-mdinclude (==0.6.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "alembic"
|
||||
version = "1.15.2"
|
||||
@@ -7739,4 +7758,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.11,<3.14"
|
||||
content-hash = "36792c7985a25a7b601fbd12db04df45508749c81090acfd2860c28884bf0e9a"
|
||||
content-hash = "4665c6fa560799864feeb4eb1c86dec10957a21373add46d8b531532cec616f3"
|
||||
|
||||
@@ -99,6 +99,7 @@ build = "^1.2.2.post1"
|
||||
pandas = "^2.2.3"
|
||||
pre-commit = "^4.2.0"
|
||||
ruff = "^0.11.12"
|
||||
aiosqlite = "^0.21.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
|
||||
52
skyvern/forge/sdk/db/client_test.py
Normal file
52
skyvern/forge/sdk/db/client_test.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from skyvern.forge.sdk.db.client import AgentDB
|
||||
from skyvern.forge.sdk.db.models import Base
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_engine() -> AsyncGenerator[Any, None]:
|
||||
# Use an in-memory SQLite database for testing
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def agent_db(db_engine: Any) -> AsyncGenerator[AgentDB, None]:
|
||||
yield AgentDB(database_string="sqlite+aiosqlite:///:memory:", debug_enabled=True, db_engine=db_engine)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_organization(agent_db: AgentDB) -> None:
|
||||
org_name = "Test Organization"
|
||||
domain = "test.com"
|
||||
organization = await agent_db.create_organization(organization_name=org_name, domain=domain)
|
||||
assert organization is not None
|
||||
assert organization.organization_name == org_name
|
||||
assert organization.domain == domain
|
||||
|
||||
retrieved_org = await agent_db.get_organization(organization.organization_id)
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.organization_name == org_name
|
||||
assert retrieved_org.domain == domain
|
||||
|
||||
retrieved_by_domain = await agent_db.get_organization_by_domain(domain=domain)
|
||||
assert retrieved_by_domain is not None
|
||||
assert retrieved_by_domain.organization_name == org_name
|
||||
assert retrieved_by_domain.domain == domain
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_not_found(agent_db: AgentDB) -> None:
|
||||
retrieved_org = await agent_db.get_organization("non_existent_id")
|
||||
assert retrieved_org is None
|
||||
|
||||
retrieved_by_domain = await agent_db.get_organization_by_domain(domain="nonexistent.com")
|
||||
assert retrieved_by_domain is None
|
||||
Reference in New Issue
Block a user