From 0ddb7e02863ac7f40e958aa976b756441205619c Mon Sep 17 00:00:00 2001 From: Asher Foa Date: Tue, 10 Jun 2025 12:04:11 -0400 Subject: [PATCH] Add some basic tests to AgentDB. (#2666) --- .github/workflows/ci.yml | 2 +- poetry.lock | 23 +++++++++++-- pyproject.toml | 1 + skyvern/forge/sdk/db/client_test.py | 52 +++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+), 3 deletions(-) create mode 100644 skyvern/forge/sdk/db/client_test.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 03bd789d..64b549c2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -107,7 +107,7 @@ jobs: env: ENABLE_OPENAI: "true" OPENAI_API_KEY: "sk-dummy" - run: poetry run pytest tests + run: poetry run pytest fe-lint-build: name: Frontend Lint and Build runs-on: ubuntu-latest diff --git a/poetry.lock b/poetry.lock index 7aa49076..5ab01b22 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "about-time" @@ -250,6 +250,25 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" +[[package]] +name = "aiosqlite" +version = "0.21.0" +description = "asyncio bridge to the standard sqlite3 module" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0"}, + {file = "aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3"}, +] + +[package.dependencies] +typing_extensions = ">=4.0" + +[package.extras] +dev = ["attribution (==1.7.1)", "black (==24.3.0)", "build (>=1.2)", "coverage[toml] (==7.6.10)", "flake8 (==7.0.0)", "flake8-bugbear (==24.12.12)", "flit (==3.10.1)", "mypy (==1.14.1)", "ufmt (==2.5.1)", "usort (==1.0.8.post1)"] +docs = ["sphinx (==8.1.3)", "sphinx-mdinclude (==0.6.1)"] + [[package]] name = "alembic" version = "1.15.2" @@ -7739,4 +7758,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.14" -content-hash = "36792c7985a25a7b601fbd12db04df45508749c81090acfd2860c28884bf0e9a" +content-hash = "4665c6fa560799864feeb4eb1c86dec10957a21373add46d8b531532cec616f3" diff --git a/pyproject.toml b/pyproject.toml index 32121fa6..3a778cf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,6 +99,7 @@ build = "^1.2.2.post1" pandas = "^2.2.3" pre-commit = "^4.2.0" ruff = "^0.11.12" +aiosqlite = "^0.21.0" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/skyvern/forge/sdk/db/client_test.py b/skyvern/forge/sdk/db/client_test.py new file mode 100644 index 00000000..870fc350 --- /dev/null +++ b/skyvern/forge/sdk/db/client_test.py @@ -0,0 +1,52 @@ +from typing import Any, AsyncGenerator + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import create_async_engine + +from skyvern.forge.sdk.db.client import AgentDB +from skyvern.forge.sdk.db.models import Base + + +@pytest_asyncio.fixture +async def db_engine() -> AsyncGenerator[Any, None]: + # Use an in-memory SQLite database for testing + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def agent_db(db_engine: Any) -> AsyncGenerator[AgentDB, None]: + yield AgentDB(database_string="sqlite+aiosqlite:///:memory:", debug_enabled=True, db_engine=db_engine) + + +@pytest.mark.asyncio +async def test_create_organization(agent_db: AgentDB) -> None: + org_name = "Test Organization" + domain = "test.com" + organization = await agent_db.create_organization(organization_name=org_name, domain=domain) + assert organization is not None + assert organization.organization_name == org_name + assert organization.domain == domain + + retrieved_org = await agent_db.get_organization(organization.organization_id) + assert retrieved_org is not None + assert retrieved_org.organization_name == org_name + assert retrieved_org.domain == domain + + retrieved_by_domain = await agent_db.get_organization_by_domain(domain=domain) + assert retrieved_by_domain is not None + assert retrieved_by_domain.organization_name == org_name + assert retrieved_by_domain.domain == domain + + +@pytest.mark.asyncio +async def test_get_organization_not_found(agent_db: AgentDB) -> None: + retrieved_org = await agent_db.get_organization("non_existent_id") + assert retrieved_org is None + + retrieved_by_domain = await agent_db.get_organization_by_domain(domain="nonexistent.com") + assert retrieved_by_domain is None