From 93c45424896a2d1626470a04dae48829a83a06fa Mon Sep 17 00:00:00 2001 From: Luyang Wang Date: Sun, 28 Jun 2026 20:36:52 -0400 Subject: [PATCH] feat(sessions): add merge_state for race-free server-side atomic state merges Add a state-only write path, merge_state(app_name, user_id, session_id, delta), that performs a server-side atomic JSON merge of the routed delta into the app/user/session state row with no events row and without advancing the session optimistic-concurrency (OCC) marker. Today the only way to persist state is append_event, which couples a commutative state merge to event-log append + whole-session OCC. Two workers updating independent state keys on the same session therefore race on the session-level marker, and one spuriously fails with "modified in storage" and must retry, even though the keys never collide. merge_state routes the delta via the existing app:/user:/unprefixed convention (temp: rejected) and merges each scope server-side: - PostgreSQL: state || CAST(:delta AS JSONB) - MySQL/MariaDB: JSON_MERGE_PATCH(state, CAST(:delta AS JSON)) - SQLite: json_patch(state, :delta) The session-row merge is issued via sqlalchemy.text touching only the state column, so update_time's onupdate never fires and the OCC marker is preserved. App/user rows are auto-created via the existing _get_or_create_state; a missing session row raises for session-scoped keys. BaseSessionService and VertexAiSessionService raise NotImplementedError, mirroring get_user_state; InMemorySessionService and SqliteSessionService implement it natively. Adds parameterized tests across all backends plus DatabaseSessionService tests proving the OCC marker is not bumped, no events row is written, and concurrent merges to independent keys do not lose updates. --- .../adk/sessions/base_session_service.py | 61 ++++ .../adk/sessions/database_session_service.py | 172 ++++++++++ .../adk/sessions/in_memory_session_service.py | 40 +++ .../adk/sessions/sqlite_session_service.py | 73 ++++ .../adk/sessions/vertex_ai_session_service.py | 24 ++ .../sessions/test_session_service.py | 320 ++++++++++++++++++ .../test_vertex_ai_session_service.py | 12 + 7 files changed, 702 insertions(+) diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index 06eb6a2534a..4487bf6342e 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -151,6 +151,67 @@ async def get_user_state( 'call get_session on each result to access the merged state.' ) + async def merge_state( + self, + *, + app_name: str, + user_id: str, + session_id: str, + delta: dict[str, Any], + ) -> None: + """Atomically merges a state delta without appending an event. + + This is a state-only write path that bypasses the event log and the + whole-session optimistic-concurrency (OCC) check that ``append_event`` + performs. It is intended for *commutative* updates to independent state + keys (counters, flags, per-user balances, feature toggles), where coupling + the write to event-log append + session-level OCC would force unrelated + writers to serialize and retry on spurious "stale session" errors. + + The ``delta`` uses the same prefix convention as ``create_session`` and + event ``state_delta``; keys are routed to the scope-appropriate storage + row: + + * ``app:`` -> app-scoped state (keyed by ``app_name``). + * ``user:`` -> user-scoped state (keyed by ``(app_name, user_id)``). + * no prefix -> session-scoped state (keyed by + ``(app_name, user_id, session_id)``). + + ``temp:`` keys are rejected with ``ValueError`` because temp state is never + persisted. + + Guarantees (for backends that implement this method): + + * No ``events`` row is written. + * The session's OCC revision marker is **not** advanced, so a concurrently + held in-memory ``Session`` does not become stale and its next + ``append_event`` still succeeds. + * App- and user-scoped merges do not require a pre-existing session and are + fully decoupled from any session's revision. ``session_id`` is used only + to route session-scoped keys. + + Merge semantics for non-scalar values are backend-dependent; see the + concrete implementations (notably ``DatabaseSessionService.merge_state``) + for the dialect-specific caveats around nested objects and ``None`` values. + For flat scalar values all backends behave identically. + + Args: + app_name: The name of the app. + user_id: The ID of the user. + session_id: The ID of the session. Required for routing session-scoped + keys; need not exist when the delta has no session-scoped keys. + delta: The state delta to merge. An empty or ``None`` delta is a no-op. + + Raises: + ValueError: If ``delta`` contains ``temp:``-prefixed keys, or if a + session-scoped key is supplied for a session that does not exist. + NotImplementedError: When the concrete ``BaseSessionService`` + implementation does not support a server-side state merge. + """ + raise NotImplementedError( + f'{type(self).__name__} does not support merge_state.' + ) + async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session object.""" if event.partial: diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 6c3572b8d66..4f71077c4d4 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -18,6 +18,7 @@ import copy from datetime import datetime from datetime import timezone +import json import logging from typing import Any from typing import AsyncIterator @@ -34,6 +35,7 @@ from sqlalchemy import event from sqlalchemy import MetaData from sqlalchemy import select + from sqlalchemy import text from sqlalchemy.engine import Connection from sqlalchemy.engine import make_url from sqlalchemy.exc import ArgumentError @@ -187,6 +189,30 @@ def _merge_state( return merged_state +def _json_merge_set_clause(dialect_name: str) -> Optional[str]: + """Returns a 'state = ' SET clause for an atomic server-side JSON merge. + + The returned SQL fragment merges a ``:delta`` bind parameter into the + existing ``state`` column in place. Only the ``state`` column is referenced, + so when issued via ``sqlalchemy.text`` no column ``onupdate`` callable (such + as the ``update_time`` revision marker) is triggered. + + Returns None for dialects that do not provide a known server-side JSON merge + function; the caller raises ``NotImplementedError`` in that case. + """ + if dialect_name == _POSTGRESQL_DIALECT: + # JSONB concatenation: shallow top-level merge; an explicit JSON null is + # stored (key kept), matching append_event's Python dict union. + return "state = state || CAST(:delta AS JSONB)" + if dialect_name in (_MYSQL_DIALECT, _MARIADB_DIALECT): + # RFC 7396: recursive merge; a JSON null deletes the key. + return "state = JSON_MERGE_PATCH(state, CAST(:delta AS JSON))" + if dialect_name == _SQLITE_DIALECT: + # RFC 7396 semantics, like the native SqliteSessionService. + return "state = json_patch(state, :delta)" + return None + + class _SchemaClasses: """A helper class to hold schema classes based on version.""" @@ -737,6 +763,152 @@ async def get_user_state( return {} return dict(storage_user_state.state or {}) + @override + async def merge_state( + self, + *, + app_name: str, + user_id: str, + session_id: str, + delta: dict[str, Any], + ) -> None: + """Atomically merges a state delta server-side without appending an event. + + Each scoped sub-delta is merged into its storage row with a single atomic + SQL statement (PostgreSQL ``state || delta``, MySQL/MariaDB + ``JSON_MERGE_PATCH``, SQLite ``json_patch``) issued via ``text`` so that + only the ``state`` column is written. As a result: + + * No ``events`` row is created. + * ``sessions.update_time`` (the optimistic-concurrency revision marker) is + never advanced, so a concurrently held ``Session`` does not go stale. + * Independent keys merged concurrently do not lose updates, because the + merge happens inside the database rather than via Python + read-modify-write. + + Merge semantics for non-scalar values differ by dialect. On PostgreSQL the + top-level keys are merged shallowly and an explicit ``None`` is stored as + JSON ``null`` (the key is kept), matching ``append_event``'s Python ``dict`` + union. On MySQL, MariaDB and SQLite the merge follows RFC 7396: nested + objects are merged recursively and a ``None`` value deletes the key. For + flat scalar values (counters, flags, balances -- the intended use case) all + dialects behave identically. Use ``append_event`` if you need uniform + Python-``dict``-union semantics for nested objects or ``None`` overwrites. + + Args: + app_name: The name of the app. + user_id: The ID of the user. + session_id: The ID of the session. Used to route session-scoped (no + prefix) keys; need not exist when ``delta`` has no session-scoped keys. + delta: The state delta to merge, using the ``app:``/``user:``/no-prefix + convention. An empty or ``None`` delta is a no-op. + + Raises: + ValueError: If ``delta`` contains ``temp:`` keys, or if a session-scoped + key targets a session that does not exist. + NotImplementedError: If the active SQL dialect provides no known + server-side JSON merge function. + """ + await self.prepare_tables() + if not delta: + return + if any(key.startswith(State.TEMP_PREFIX) for key in delta): + raise ValueError( + "merge_state does not support temp: keys; temp state is never" + " persisted." + ) + + merge_clause = _json_merge_set_clause(self.db_engine.dialect.name) + if merge_clause is None: + raise ValueError( + "merge_state is not supported for dialect" + f" {self.db_engine.dialect.name!r}: no server-side JSON merge" + " function is available." + ) + + state_deltas = _session_util.extract_state_delta(delta) + app_delta = state_deltas["app"] + user_delta = state_deltas["user"] + session_delta = state_deltas["session"] + + schema = self._get_schema_classes() + app_table = schema.StorageAppState.__tablename__ + user_table = schema.StorageUserState.__tablename__ + session_table = schema.StorageSession.__tablename__ + use_row_level_locking = self._supports_row_level_locking() + + async with self._with_session_lock( + app_name=app_name, + user_id=user_id, + session_id=session_id, + ): + async with self._rollback_on_exception_session() as sql_session: + if session_delta: + # A session-scoped key requires the session row to exist; never + # auto-create it (that is create_session's responsibility). + session_exists_stmt = ( + select(schema.StorageSession.id) + .filter(schema.StorageSession.app_name == app_name) + .filter(schema.StorageSession.user_id == user_id) + .filter(schema.StorageSession.id == session_id) + ) + if use_row_level_locking: + session_exists_stmt = session_exists_stmt.with_for_update() + session_exists = await sql_session.execute(session_exists_stmt) + if session_exists.scalar_one_or_none() is None: + raise ValueError(f"Session {session_id} not found.") + + if app_delta: + # Ensure the row exists (handles concurrent inserts), then merge + # server-side. App/user rows have their own update_time which is not + # an OCC marker, so merging here is independent of any session. + await _get_or_create_state( + sql_session=sql_session, + state_model=schema.StorageAppState, + primary_key=app_name, + defaults={"app_name": app_name, "state": {}}, + ) + await sql_session.execute( + text( + f"UPDATE {app_table} SET {merge_clause} WHERE" + " app_name = :app_name" + ), + {"app_name": app_name, "delta": json.dumps(app_delta)}, + ) + if user_delta: + await _get_or_create_state( + sql_session=sql_session, + state_model=schema.StorageUserState, + primary_key=(app_name, user_id), + defaults={"app_name": app_name, "user_id": user_id, "state": {}}, + ) + await sql_session.execute( + text( + f"UPDATE {user_table} SET {merge_clause} WHERE" + " app_name = :app_name AND user_id = :user_id" + ), + { + "app_name": app_name, + "user_id": user_id, + "delta": json.dumps(user_delta), + }, + ) + if session_delta: + await sql_session.execute( + text( + f"UPDATE {session_table} SET {merge_clause} WHERE" + " app_name = :app_name AND user_id = :user_id" + " AND id = :session_id" + ), + { + "app_name": app_name, + "user_id": user_id, + "session_id": session_id, + "delta": json.dumps(session_delta), + }, + ) + await sql_session.commit() + @override async def append_event(self, session: Session, event: Event) -> Event: await self.prepare_tables() diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index 73a54f398b8..9e861a2c96c 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -318,6 +318,46 @@ async def get_user_state( ) -> dict[str, Any]: return dict(self.user_state.get(app_name, {}).get(user_id, {})) + @override + async def merge_state( + self, + *, + app_name: str, + user_id: str, + session_id: str, + delta: dict[str, Any], + ) -> None: + if not delta: + return + if any(key.startswith(State.TEMP_PREFIX) for key in delta): + raise ValueError( + 'merge_state does not support temp: keys; temp state is never' + ' persisted.' + ) + + state_deltas = _session_util.extract_state_delta(delta) + app_state_delta = state_deltas['app'] + user_state_delta = state_deltas['user'] + session_state_delta = state_deltas['session'] + + if session_state_delta: + storage_session = ( + self.sessions.get(app_name, {}).get(user_id, {}).get(session_id) + ) + if storage_session is None: + raise ValueError(f'Session {session_id} not found.') + + if app_state_delta: + self.app_state.setdefault(app_name, {}).update(app_state_delta) + if user_state_delta: + self.user_state.setdefault(app_name, {}).setdefault(user_id, {}).update( + user_state_delta + ) + if session_state_delta: + # Merge into the stored session's state without bumping + # last_update_time, so a concurrently held session does not go stale. + storage_session.state.update(session_state_delta) + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index d0d699e4c3a..1835fbad0e5 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -367,6 +367,54 @@ async def get_user_state( async with self._get_db_connection() as db: return await self._get_user_state(db, app_name, user_id) + @override + async def merge_state( + self, + *, + app_name: str, + user_id: str, + session_id: str, + delta: dict[str, Any], + ) -> None: + if not delta: + return + if any(key.startswith(State.TEMP_PREFIX) for key in delta): + raise ValueError( + "merge_state does not support temp: keys; temp state is never" + " persisted." + ) + + state_deltas = _session_util.extract_state_delta(delta) + app_state_delta = state_deltas["app"] + user_state_delta = state_deltas["user"] + session_state_delta = state_deltas["session"] + now = platform_time.get_time() + + async with self._get_db_connection() as db: + if session_state_delta: + async with db.execute( + "SELECT 1 FROM sessions WHERE app_name=? AND user_id=? AND id=?", + (app_name, user_id, session_id), + ) as cursor: + if await cursor.fetchone() is None: + raise ValueError(f"Session {session_id} not found.") + + # Each merge below uses an atomic json_patch on the storage row, so no + # read-modify-write and no whole-session OCC check is needed. + if app_state_delta: + await self._upsert_app_state(db, app_name, app_state_delta, now) + if user_state_delta: + await self._upsert_user_state( + db, app_name, user_id, user_state_delta, now + ) + if session_state_delta: + # Merge the session state WITHOUT bumping sessions.update_time, so a + # concurrently held session does not go stale on its next append_event. + await self._merge_session_state_in_db( + db, app_name, user_id, session_id, session_state_delta + ) + await db.commit() + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: @@ -567,6 +615,31 @@ async def _update_session_state_in_db( ), ) + async def _merge_session_state_in_db( + self, + db: aiosqlite.Connection, + app_name: str, + user_id: str, + session_id: str, + delta: dict, + ) -> None: + """Atomically merges session state via json_patch without bumping update_time. + + Unlike _update_session_state_in_db, this intentionally leaves + sessions.update_time untouched so that merge_state does not advance the + optimistic-concurrency marker derived from it. + """ + await db.execute( + "UPDATE sessions SET state=json_patch(state, ?) WHERE" + " app_name=? AND user_id=? AND id=?", + ( + json.dumps(delta), + app_name, + user_id, + session_id, + ), + ) + def _is_migration_needed(self) -> bool: """Checks if migration to new schema is needed.""" if not os.path.exists(self._db_path): diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 128e101f164..251d765af10 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -381,6 +381,30 @@ async def get_user_state( 'via list_sessions and call get_session on each result.' ) + @override + async def merge_state( + self, + *, + app_name: str, + user_id: str, + session_id: str, + delta: dict[str, Any], + ) -> None: + """Not supported by the Vertex AI Agent Engine backend. + + The Vertex AI Agent Engine API does not expose a server-side state merge + independent of the event log. To persist state, append an event with a + ``state_delta`` via ``append_event``. + + Raises: + NotImplementedError: Always. + """ + raise NotImplementedError( + 'VertexAiSessionService does not support merge_state. ' + 'To persist state, append an event with a state_delta via ' + 'append_event.' + ) + @override async def append_event(self, session: Session, event: Event) -> Event: # Update the in-memory session. diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 157e4fb21aa..bbdd7066c55 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -34,6 +34,7 @@ from google.genai import types import pytest from sqlalchemy import delete +from sqlalchemy import select from sqlalchemy import text from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.pool import StaticPool @@ -597,6 +598,181 @@ async def test_temp_state_visible_across_sequential_events(session_service): assert 'temp:output' not in event1.actions.state_delta +@pytest.mark.asyncio +async def test_merge_state_merges_session_scoped_key(session_service): + app_name = 'my_app' + user_id = 'u1' + session = await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='s1', state={'sk1': 'v1'} + ) + + await session_service.merge_state( + app_name=app_name, + user_id=user_id, + session_id=session.id, + delta={'sk2': 'v2'}, + ) + + got = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert got.state.get('sk1') == 'v1' + assert got.state.get('sk2') == 'v2' + # merge_state must not write an event. + assert got.events == [] + + +@pytest.mark.asyncio +async def test_merge_state_merges_app_scoped_key(session_service): + app_name = 'my_app' + session = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + + await session_service.merge_state( + app_name=app_name, + user_id='u1', + session_id=session.id, + delta={'app:flag': True}, + ) + + # Visible to the same session and to a different user's session. + got = await session_service.get_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + assert got.state.get('app:flag') is True + other = await session_service.create_session( + app_name=app_name, user_id='u2', session_id='s2' + ) + assert other.state.get('app:flag') is True + + +@pytest.mark.asyncio +async def test_merge_state_merges_user_scoped_key(session_service): + app_name = 'my_app' + session = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + + await session_service.merge_state( + app_name=app_name, + user_id='u1', + session_id=session.id, + delta={'user:tier': 'gold'}, + ) + + assert await session_service.get_user_state( + app_name=app_name, user_id='u1' + ) == {'tier': 'gold'} + got = await session_service.get_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + assert got.state.get('user:tier') == 'gold' + + +@pytest.mark.asyncio +async def test_merge_state_cross_scope_routing(session_service): + app_name = 'my_app' + session = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + + await session_service.merge_state( + app_name=app_name, + user_id='u1', + session_id=session.id, + delta={'app:a': 1, 'user:b': 2, 'sk': 3}, + ) + + got = await session_service.get_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + assert got.state.get('app:a') == 1 + assert got.state.get('user:b') == 2 + assert got.state.get('sk') == 3 + + +@pytest.mark.asyncio +async def test_merge_state_rejects_temp_keys(session_service): + app_name = 'my_app' + session = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + + with pytest.raises(ValueError, match='temp:'): + await session_service.merge_state( + app_name=app_name, + user_id='u1', + session_id=session.id, + delta={'temp:x': 1}, + ) + # Mixed delta with a temp: key is rejected entirely (nothing persisted). + with pytest.raises(ValueError, match='temp:'): + await session_service.merge_state( + app_name=app_name, + user_id='u1', + session_id=session.id, + delta={'sk': 1, 'temp:x': 2}, + ) + got = await session_service.get_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + assert 'sk' not in got.state + + +@pytest.mark.asyncio +async def test_merge_state_empty_delta_is_noop(session_service): + app_name = 'my_app' + session = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1', state={'sk1': 'v1'} + ) + + await session_service.merge_state( + app_name=app_name, user_id='u1', session_id=session.id, delta={} + ) + + got = await session_service.get_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + assert got.state == {'sk1': 'v1'} + assert got.events == [] + + +@pytest.mark.asyncio +async def test_merge_state_missing_session_raises_for_session_delta( + session_service, +): + with pytest.raises(ValueError, match='not found'): + await session_service.merge_state( + app_name='my_app', + user_id='u1', + session_id='does_not_exist', + delta={'sk': 1}, + ) + + +@pytest.mark.asyncio +async def test_merge_state_auto_creates_app_user_rows(session_service): + # No create_session for this (app, user); app/user merges should still work + # and not require a session to exist. + app_name = 'fresh_app' + await session_service.merge_state( + app_name=app_name, + user_id='fresh_user', + session_id='unused', + delta={'app:x': 1, 'user:y': 2}, + ) + + assert await session_service.get_user_state( + app_name=app_name, user_id='fresh_user' + ) == {'y': 2} + # The app-scoped value is observable via a freshly created session. + session = await session_service.create_session( + app_name=app_name, user_id='fresh_user', session_id='s1' + ) + assert session.state.get('app:x') == 1 + + @pytest.mark.asyncio async def test_get_session_respects_user_id(session_service): app_name = 'my_app' @@ -823,6 +999,150 @@ async def test_append_event_to_stale_session(): ] +@pytest.mark.asyncio +async def test_merge_state_does_not_bump_occ_marker_for_session_scope(): + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + session = await service.create_session( + app_name='my_app', user_id='user', session_id='s1' + ) + original_marker = session._storage_update_marker + + await service.merge_state( + app_name='my_app', + user_id='user', + session_id='s1', + delta={'sk': 'v'}, + ) + + # The stored OCC marker must be unchanged by the merge. + schema = service._get_schema_classes() + async with service.database_session_factory() as sql_session: + storage_session = await sql_session.get( + schema.StorageSession, ('my_app', 'user', 's1') + ) + assert storage_session.get_update_marker() == original_marker + assert storage_session.state.get('sk') == 'v' + + # The originally-held session is NOT stale: it can still append. + event = Event( + invocation_id='inv1', + author='user', + timestamp=session.last_update_time + 10, + ) + await service.append_event(session, event) + + final = await service.get_session( + app_name='my_app', user_id='user', session_id='s1' + ) + assert final.state.get('sk') == 'v' + assert [e.invocation_id for e in final.events] == ['inv1'] + finally: + await service.close() + + +@pytest.mark.asyncio +async def test_merge_state_does_not_bump_occ_marker_for_app_user_scope(): + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + session = await service.create_session( + app_name='my_app', user_id='user', session_id='s1' + ) + original_marker = session._storage_update_marker + + await service.merge_state( + app_name='my_app', + user_id='user', + session_id='s1', + delta={'app:a': 1, 'user:b': 2}, + ) + + schema = service._get_schema_classes() + async with service.database_session_factory() as sql_session: + storage_session = await sql_session.get( + schema.StorageSession, ('my_app', 'user', 's1') + ) + assert storage_session.get_update_marker() == original_marker + + # The held session can still append after app/user merges. + event = Event( + invocation_id='inv1', + author='user', + timestamp=session.last_update_time + 10, + ) + await service.append_event(session, event) + + final = await service.get_session( + app_name='my_app', user_id='user', session_id='s1' + ) + assert final.state.get('app:a') == 1 + assert final.state.get('user:b') == 2 + assert [e.invocation_id for e in final.events] == ['inv1'] + finally: + await service.close() + + +@pytest.mark.asyncio +async def test_merge_state_no_events_row_created(): + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + await service.create_session( + app_name='my_app', user_id='user', session_id='s1' + ) + + await service.merge_state( + app_name='my_app', + user_id='user', + session_id='s1', + delta={'app:a': 1, 'user:b': 2, 'sk': 3}, + ) + + schema = service._get_schema_classes() + async with service.database_session_factory() as sql_session: + result = await sql_session.execute(select(schema.StorageEvent)) + assert result.scalars().all() == [] + finally: + await service.close() + + +@pytest.mark.asyncio +async def test_merge_state_concurrent_independent_keys_no_lost_update(): + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + await service.create_session( + app_name='my_app', user_id='user', session_id='s1' + ) + + iteration_count = 8 + for i in range(iteration_count): + await asyncio.gather( + service.merge_state( + app_name='my_app', + user_id='user', + session_id='s1', + delta={f'sk{i}-1': f'v{i}-1'}, + ), + service.merge_state( + app_name='my_app', + user_id='user', + session_id='s1', + delta={f'sk{i}-2': f'v{i}-2'}, + ), + ) + + final = await service.get_session( + app_name='my_app', user_id='user', session_id='s1' + ) + # Both independent keys from every iteration must be present (no lost + # update), and no event row was created. + for i in range(iteration_count): + assert final.state.get(f'sk{i}-1') == f'v{i}-1' + assert final.state.get(f'sk{i}-2') == f'v{i}-2' + assert final.events == [] + finally: + await service.close() + + @pytest.mark.asyncio async def test_append_event_raises_if_app_state_row_missing(): service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 7f3cb61a05b..1cbdbd77d22 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -656,6 +656,18 @@ async def test_initialize_with_project_location_and_api_key_error(): ) +@pytest.mark.asyncio +async def test_merge_state_not_implemented(): + session_service = mock_vertex_ai_session_service() + with pytest.raises(NotImplementedError, match='merge_state'): + await session_service.merge_state( + app_name='123', + user_id='user', + session_id='1', + delta={'k': 'v'}, + ) + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_get_session_returns_none_when_invalid_argument(