Add some basic tests to AgentDB. (#2666)
This commit is contained in:
52
skyvern/forge/sdk/db/client_test.py
Normal file
52
skyvern/forge/sdk/db/client_test.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from skyvern.forge.sdk.db.client import AgentDB
|
||||
from skyvern.forge.sdk.db.models import Base
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_engine() -> AsyncGenerator[Any, None]:
|
||||
# Use an in-memory SQLite database for testing
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def agent_db(db_engine: Any) -> AsyncGenerator[AgentDB, None]:
|
||||
yield AgentDB(database_string="sqlite+aiosqlite:///:memory:", debug_enabled=True, db_engine=db_engine)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_organization(agent_db: AgentDB) -> None:
|
||||
org_name = "Test Organization"
|
||||
domain = "test.com"
|
||||
organization = await agent_db.create_organization(organization_name=org_name, domain=domain)
|
||||
assert organization is not None
|
||||
assert organization.organization_name == org_name
|
||||
assert organization.domain == domain
|
||||
|
||||
retrieved_org = await agent_db.get_organization(organization.organization_id)
|
||||
assert retrieved_org is not None
|
||||
assert retrieved_org.organization_name == org_name
|
||||
assert retrieved_org.domain == domain
|
||||
|
||||
retrieved_by_domain = await agent_db.get_organization_by_domain(domain=domain)
|
||||
assert retrieved_by_domain is not None
|
||||
assert retrieved_by_domain.organization_name == org_name
|
||||
assert retrieved_by_domain.domain == domain
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_organization_not_found(agent_db: AgentDB) -> None:
|
||||
retrieved_org = await agent_db.get_organization("non_existent_id")
|
||||
assert retrieved_org is None
|
||||
|
||||
retrieved_by_domain = await agent_db.get_organization_by_domain(domain="nonexistent.com")
|
||||
assert retrieved_by_domain is None
|
||||
Reference in New Issue
Block a user