diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index 06eb6a2534..4487bf6342 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -151,6 +151,67 @@ async def get_user_state( 'call get_session on each result to access the merged state.' ) + async def merge_state( + self, + *, + app_name: str, + user_id: str, + session_id: str, + delta: dict[str, Any], + ) -> None: + """Atomically merges a state delta without appending an event. + + This is a state-only write path that bypasses the event log and the + whole-session optimistic-concurrency (OCC) check that ``append_event`` + performs. It is intended for *commutative* updates to independent state + keys (counters, flags, per-user balances, feature toggles), where coupling + the write to event-log append + session-level OCC would force unrelated + writers to serialize and retry on spurious "stale session" errors. + + The ``delta`` uses the same prefix convention as ``create_session`` and + event ``state_delta``; keys are routed to the scope-appropriate storage + row: + + * ``app:`` -> app-scoped state (keyed by ``app_name``). + * ``user:`` -> user-scoped state (keyed by ``(app_name, user_id)``). + * no prefix -> session-scoped state (keyed by + ``(app_name, user_id, session_id)``). + + ``temp:`` keys are rejected with ``ValueError`` because temp state is never + persisted. + + Guarantees (for backends that implement this method): + + * No ``events`` row is written. + * The session's OCC revision marker is **not** advanced, so a concurrently + held in-memory ``Session`` does not become stale and its next + ``append_event`` still succeeds. + * App- and user-scoped merges do not require a pre-existing session and are + fully decoupled from any session's revision. ``session_id`` is used only + to route session-scoped keys. + + Merge semantics for non-scalar values are backend-dependent; see the + concrete implementations (notably ``DatabaseSessionService.merge_state``) + for the dialect-specific caveats around nested objects and ``None`` values. + For flat scalar values all backends behave identically. + + Args: + app_name: The name of the app. + user_id: The ID of the user. + session_id: The ID of the session. Required for routing session-scoped + keys; need not exist when the delta has no session-scoped keys. + delta: The state delta to merge. An empty or ``None`` delta is a no-op. + + Raises: + ValueError: If ``delta`` contains ``temp:``-prefixed keys, or if a + session-scoped key is supplied for a session that does not exist. + NotImplementedError: When the concrete ``BaseSessionService`` + implementation does not support a server-side state merge. + """ + raise NotImplementedError( + f'{type(self).__name__} does not support merge_state.' + ) + async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session object.""" if event.partial: diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 6c3572b8d6..4f71077c4d 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -18,6 +18,7 @@ import copy from datetime import datetime from datetime import timezone +import json import logging from typing import Any from typing import AsyncIterator @@ -34,6 +35,7 @@ from sqlalchemy import event from sqlalchemy import MetaData from sqlalchemy import select + from sqlalchemy import text from sqlalchemy.engine import Connection from sqlalchemy.engine import make_url from sqlalchemy.exc import ArgumentError @@ -187,6 +189,30 @@ def _merge_state( return merged_state +def _json_merge_set_clause(dialect_name: str) -> Optional[str]: + """Returns a 'state = ' SET clause for an atomic server-side JSON merge. + + The returned SQL fragment merges a ``:delta`` bind parameter into the + existing ``state`` column in place. Only the ``state`` column is referenced, + so when issued via ``sqlalchemy.text`` no column ``onupdate`` callable (such + as the ``update_time`` revision marker) is triggered. + + Returns None for dialects that do not provide a known server-side JSON merge + function; the caller raises ``NotImplementedError`` in that case. + """ + if dialect_name == _POSTGRESQL_DIALECT: + # JSONB concatenation: shallow top-level merge; an explicit JSON null is + # stored (key kept), matching append_event's Python dict union. + return "state = state || CAST(:delta AS JSONB)" + if dialect_name in (_MYSQL_DIALECT, _MARIADB_DIALECT): + # RFC 7396: recursive merge; a JSON null deletes the key. + return "state = JSON_MERGE_PATCH(state, CAST(:delta AS JSON))" + if dialect_name == _SQLITE_DIALECT: + # RFC 7396 semantics, like the native SqliteSessionService. + return "state = json_patch(state, :delta)" + return None + + class _SchemaClasses: """A helper class to hold schema classes based on version.""" @@ -737,6 +763,152 @@ async def get_user_state( return {} return dict(storage_user_state.state or {}) + @override + async def merge_state( + self, + *, + app_name: str, + user_id: str, + session_id: str, + delta: dict[str, Any], + ) -> None: + """Atomically merges a state delta server-side without appending an event. + + Each scoped sub-delta is merged into its storage row with a single atomic + SQL statement (PostgreSQL ``state || delta``, MySQL/MariaDB + ``JSON_MERGE_PATCH``, SQLite ``json_patch``) issued via ``text`` so that + only the ``state`` column is written. As a result: + + * No ``events`` row is created. + * ``sessions.update_time`` (the optimistic-concurrency revision marker) is + never advanced, so a concurrently held ``Session`` does not go stale. + * Independent keys merged concurrently do not lose updates, because the + merge happens inside the database rather than via Python + read-modify-write. + + Merge semantics for non-scalar values differ by dialect. On PostgreSQL the + top-level keys are merged shallowly and an explicit ``None`` is stored as + JSON ``null`` (the key is kept), matching ``append_event``'s Python ``dict`` + union. On MySQL, MariaDB and SQLite the merge follows RFC 7396: nested + objects are merged recursively and a ``None`` value deletes the key. For + flat scalar values (counters, flags, balances -- the intended use case) all + dialects behave identically. Use ``append_event`` if you need uniform + Python-``dict``-union semantics for nested objects or ``None`` overwrites. + + Args: + app_name: The name of the app. + user_id: The ID of the user. + session_id: The ID of the session. Used to route session-scoped (no + prefix) keys; need not exist when ``delta`` has no session-scoped keys. + delta: The state delta to merge, using the ``app:``/``user:``/no-prefix + convention. An empty or ``None`` delta is a no-op. + + Raises: + ValueError: If ``delta`` contains ``temp:`` keys, or if a session-scoped + key targets a session that does not exist. + NotImplementedError: If the active SQL dialect provides no known + server-side JSON merge function. + """ + await self.prepare_tables() + if not delta: + return + if any(key.startswith(State.TEMP_PREFIX) for key in delta): + raise ValueError( + "merge_state does not support temp: keys; temp state is never" + " persisted." + ) + + merge_clause = _json_merge_set_clause(self.db_engine.dialect.name) + if merge_clause is None: + raise ValueError( + "merge_state is not supported for dialect" + f" {self.db_engine.dialect.name!r}: no server-side JSON merge" + " function is available." + ) + + state_deltas = _session_util.extract_state_delta(delta) + app_delta = state_deltas["app"] + user_delta = state_deltas["user"] + session_delta = state_deltas["session"] + + schema = self._get_schema_classes() + app_table = schema.StorageAppState.__tablename__ + user_table = schema.StorageUserState.__tablename__ + session_table = schema.StorageSession.__tablename__ + use_row_level_locking = self._supports_row_level_locking() + + async with self._with_session_lock( + app_name=app_name, + user_id=user_id, + session_id=session_id, + ): + async with self._rollback_on_exception_session() as sql_session: + if session_delta: + # A session-scoped key requires the session row to exist; never + # auto-create it (that is create_session's responsibility). + session_exists_stmt = ( + select(schema.StorageSession.id) + .filter(schema.StorageSession.app_name == app_name) + .filter(schema.StorageSession.user_id == user_id) + .filter(schema.StorageSession.id == session_id) + ) + if use_row_level_locking: + session_exists_stmt = session_exists_stmt.with_for_update() + session_exists = await sql_session.execute(session_exists_stmt) + if session_exists.scalar_one_or_none() is None: + raise ValueError(f"Session {session_id} not found.") + + if app_delta: + # Ensure the row exists (handles concurrent inserts), then merge + # server-side. App/user rows have their own update_time which is not + # an OCC marker, so merging here is independent of any session. + await _get_or_create_state( + sql_session=sql_session, + state_model=schema.StorageAppState, + primary_key=app_name, + defaults={"app_name": app_name, "state": {}}, + ) + await sql_session.execute( + text( + f"UPDATE {app_table} SET {merge_clause} WHERE" + " app_name = :app_name" + ), + {"app_name": app_name, "delta": json.dumps(app_delta)}, + ) + if user_delta: + await _get_or_create_state( + sql_session=sql_session, + state_model=schema.StorageUserState, + primary_key=(app_name, user_id), + defaults={"app_name": app_name, "user_id": user_id, "state": {}}, + ) + await sql_session.execute( + text( + f"UPDATE {user_table} SET {merge_clause} WHERE" + " app_name = :app_name AND user_id = :user_id" + ), + { + "app_name": app_name, + "user_id": user_id, + "delta": json.dumps(user_delta), + }, + ) + if session_delta: + await sql_session.execute( + text( + f"UPDATE {session_table} SET {merge_clause} WHERE" + " app_name = :app_name AND user_id = :user_id" + " AND id = :session_id" + ), + { + "app_name": app_name, + "user_id": user_id, + "session_id": session_id, + "delta": json.dumps(session_delta), + }, + ) + await sql_session.commit() + @override async def append_event(self, session: Session, event: Event) -> Event: await self.prepare_tables() diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index 73a54f398b..9e861a2c96 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -318,6 +318,46 @@ async def get_user_state( ) -> dict[str, Any]: return dict(self.user_state.get(app_name, {}).get(user_id, {})) + @override + async def merge_state( + self, + *, + app_name: str, + user_id: str, + session_id: str, + delta: dict[str, Any], + ) -> None: + if not delta: + return + if any(key.startswith(State.TEMP_PREFIX) for key in delta): + raise ValueError( + 'merge_state does not support temp: keys; temp state is never' + ' persisted.' + ) + + state_deltas = _session_util.extract_state_delta(delta) + app_state_delta = state_deltas['app'] + user_state_delta = state_deltas['user'] + session_state_delta = state_deltas['session'] + + if session_state_delta: + storage_session = ( + self.sessions.get(app_name, {}).get(user_id, {}).get(session_id) + ) + if storage_session is None: + raise ValueError(f'Session {session_id} not found.') + + if app_state_delta: + self.app_state.setdefault(app_name, {}).update(app_state_delta) + if user_state_delta: + self.user_state.setdefault(app_name, {}).setdefault(user_id, {}).update( + user_state_delta + ) + if session_state_delta: + # Merge into the stored session's state without bumping + # last_update_time, so a concurrently held session does not go stale. + storage_session.state.update(session_state_delta) + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index d0d699e4c3..1835fbad0e 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -367,6 +367,54 @@ async def get_user_state( async with self._get_db_connection() as db: return await self._get_user_state(db, app_name, user_id) + @override + async def merge_state( + self, + *, + app_name: str, + user_id: str, + session_id: str, + delta: dict[str, Any], + ) -> None: + if not delta: + return + if any(key.startswith(State.TEMP_PREFIX) for key in delta): + raise ValueError( + "merge_state does not support temp: keys; temp state is never" + " persisted." + ) + + state_deltas = _session_util.extract_state_delta(delta) + app_state_delta = state_deltas["app"] + user_state_delta = state_deltas["user"] + session_state_delta = state_deltas["session"] + now = platform_time.get_time() + + async with self._get_db_connection() as db: + if session_state_delta: + async with db.execute( + "SELECT 1 FROM sessions WHERE app_name=? AND user_id=? AND id=?", + (app_name, user_id, session_id), + ) as cursor: + if await cursor.fetchone() is None: + raise ValueError(f"Session {session_id} not found.") + + # Each merge below uses an atomic json_patch on the storage row, so no + # read-modify-write and no whole-session OCC check is needed. + if app_state_delta: + await self._upsert_app_state(db, app_name, app_state_delta, now) + if user_state_delta: + await self._upsert_user_state( + db, app_name, user_id, user_state_delta, now + ) + if session_state_delta: + # Merge the session state WITHOUT bumping sessions.update_time, so a + # concurrently held session does not go stale on its next append_event. + await self._merge_session_state_in_db( + db, app_name, user_id, session_id, session_state_delta + ) + await db.commit() + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: @@ -567,6 +615,31 @@ async def _update_session_state_in_db( ), ) + async def _merge_session_state_in_db( + self, + db: aiosqlite.Connection, + app_name: str, + user_id: str, + session_id: str, + delta: dict, + ) -> None: + """Atomically merges session state via json_patch without bumping update_time. + + Unlike _update_session_state_in_db, this intentionally leaves + sessions.update_time untouched so that merge_state does not advance the + optimistic-concurrency marker derived from it. + """ + await db.execute( + "UPDATE sessions SET state=json_patch(state, ?) WHERE" + " app_name=? AND user_id=? AND id=?", + ( + json.dumps(delta), + app_name, + user_id, + session_id, + ), + ) + def _is_migration_needed(self) -> bool: """Checks if migration to new schema is needed.""" if not os.path.exists(self._db_path): diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 128e101f16..251d765af1 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -381,6 +381,30 @@ async def get_user_state( 'via list_sessions and call get_session on each result.' ) + @override + async def merge_state( + self, + *, + app_name: str, + user_id: str, + session_id: str, + delta: dict[str, Any], + ) -> None: + """Not supported by the Vertex AI Agent Engine backend. + + The Vertex AI Agent Engine API does not expose a server-side state merge + independent of the event log. To persist state, append an event with a + ``state_delta`` via ``append_event``. + + Raises: + NotImplementedError: Always. + """ + raise NotImplementedError( + 'VertexAiSessionService does not support merge_state. ' + 'To persist state, append an event with a state_delta via ' + 'append_event.' + ) + @override async def append_event(self, session: Session, event: Event) -> Event: # Update the in-memory session. diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 157e4fb21a..bbdd7066c5 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -34,6 +34,7 @@ from google.genai import types import pytest from sqlalchemy import delete +from sqlalchemy import select from sqlalchemy import text from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.pool import StaticPool @@ -597,6 +598,181 @@ async def test_temp_state_visible_across_sequential_events(session_service): assert 'temp:output' not in event1.actions.state_delta +@pytest.mark.asyncio +async def test_merge_state_merges_session_scoped_key(session_service): + app_name = 'my_app' + user_id = 'u1' + session = await session_service.create_session( + app_name=app_name, user_id=user_id, session_id='s1', state={'sk1': 'v1'} + ) + + await session_service.merge_state( + app_name=app_name, + user_id=user_id, + session_id=session.id, + delta={'sk2': 'v2'}, + ) + + got = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert got.state.get('sk1') == 'v1' + assert got.state.get('sk2') == 'v2' + # merge_state must not write an event. + assert got.events == [] + + +@pytest.mark.asyncio +async def test_merge_state_merges_app_scoped_key(session_service): + app_name = 'my_app' + session = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + + await session_service.merge_state( + app_name=app_name, + user_id='u1', + session_id=session.id, + delta={'app:flag': True}, + ) + + # Visible to the same session and to a different user's session. + got = await session_service.get_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + assert got.state.get('app:flag') is True + other = await session_service.create_session( + app_name=app_name, user_id='u2', session_id='s2' + ) + assert other.state.get('app:flag') is True + + +@pytest.mark.asyncio +async def test_merge_state_merges_user_scoped_key(session_service): + app_name = 'my_app' + session = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + + await session_service.merge_state( + app_name=app_name, + user_id='u1', + session_id=session.id, + delta={'user:tier': 'gold'}, + ) + + assert await session_service.get_user_state( + app_name=app_name, user_id='u1' + ) == {'tier': 'gold'} + got = await session_service.get_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + assert got.state.get('user:tier') == 'gold' + + +@pytest.mark.asyncio +async def test_merge_state_cross_scope_routing(session_service): + app_name = 'my_app' + session = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + + await session_service.merge_state( + app_name=app_name, + user_id='u1', + session_id=session.id, + delta={'app:a': 1, 'user:b': 2, 'sk': 3}, + ) + + got = await session_service.get_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + assert got.state.get('app:a') == 1 + assert got.state.get('user:b') == 2 + assert got.state.get('sk') == 3 + + +@pytest.mark.asyncio +async def test_merge_state_rejects_temp_keys(session_service): + app_name = 'my_app' + session = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + + with pytest.raises(ValueError, match='temp:'): + await session_service.merge_state( + app_name=app_name, + user_id='u1', + session_id=session.id, + delta={'temp:x': 1}, + ) + # Mixed delta with a temp: key is rejected entirely (nothing persisted). + with pytest.raises(ValueError, match='temp:'): + await session_service.merge_state( + app_name=app_name, + user_id='u1', + session_id=session.id, + delta={'sk': 1, 'temp:x': 2}, + ) + got = await session_service.get_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + assert 'sk' not in got.state + + +@pytest.mark.asyncio +async def test_merge_state_empty_delta_is_noop(session_service): + app_name = 'my_app' + session = await session_service.create_session( + app_name=app_name, user_id='u1', session_id='s1', state={'sk1': 'v1'} + ) + + await session_service.merge_state( + app_name=app_name, user_id='u1', session_id=session.id, delta={} + ) + + got = await session_service.get_session( + app_name=app_name, user_id='u1', session_id='s1' + ) + assert got.state == {'sk1': 'v1'} + assert got.events == [] + + +@pytest.mark.asyncio +async def test_merge_state_missing_session_raises_for_session_delta( + session_service, +): + with pytest.raises(ValueError, match='not found'): + await session_service.merge_state( + app_name='my_app', + user_id='u1', + session_id='does_not_exist', + delta={'sk': 1}, + ) + + +@pytest.mark.asyncio +async def test_merge_state_auto_creates_app_user_rows(session_service): + # No create_session for this (app, user); app/user merges should still work + # and not require a session to exist. + app_name = 'fresh_app' + await session_service.merge_state( + app_name=app_name, + user_id='fresh_user', + session_id='unused', + delta={'app:x': 1, 'user:y': 2}, + ) + + assert await session_service.get_user_state( + app_name=app_name, user_id='fresh_user' + ) == {'y': 2} + # The app-scoped value is observable via a freshly created session. + session = await session_service.create_session( + app_name=app_name, user_id='fresh_user', session_id='s1' + ) + assert session.state.get('app:x') == 1 + + @pytest.mark.asyncio async def test_get_session_respects_user_id(session_service): app_name = 'my_app' @@ -823,6 +999,150 @@ async def test_append_event_to_stale_session(): ] +@pytest.mark.asyncio +async def test_merge_state_does_not_bump_occ_marker_for_session_scope(): + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + session = await service.create_session( + app_name='my_app', user_id='user', session_id='s1' + ) + original_marker = session._storage_update_marker + + await service.merge_state( + app_name='my_app', + user_id='user', + session_id='s1', + delta={'sk': 'v'}, + ) + + # The stored OCC marker must be unchanged by the merge. + schema = service._get_schema_classes() + async with service.database_session_factory() as sql_session: + storage_session = await sql_session.get( + schema.StorageSession, ('my_app', 'user', 's1') + ) + assert storage_session.get_update_marker() == original_marker + assert storage_session.state.get('sk') == 'v' + + # The originally-held session is NOT stale: it can still append. + event = Event( + invocation_id='inv1', + author='user', + timestamp=session.last_update_time + 10, + ) + await service.append_event(session, event) + + final = await service.get_session( + app_name='my_app', user_id='user', session_id='s1' + ) + assert final.state.get('sk') == 'v' + assert [e.invocation_id for e in final.events] == ['inv1'] + finally: + await service.close() + + +@pytest.mark.asyncio +async def test_merge_state_does_not_bump_occ_marker_for_app_user_scope(): + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + session = await service.create_session( + app_name='my_app', user_id='user', session_id='s1' + ) + original_marker = session._storage_update_marker + + await service.merge_state( + app_name='my_app', + user_id='user', + session_id='s1', + delta={'app:a': 1, 'user:b': 2}, + ) + + schema = service._get_schema_classes() + async with service.database_session_factory() as sql_session: + storage_session = await sql_session.get( + schema.StorageSession, ('my_app', 'user', 's1') + ) + assert storage_session.get_update_marker() == original_marker + + # The held session can still append after app/user merges. + event = Event( + invocation_id='inv1', + author='user', + timestamp=session.last_update_time + 10, + ) + await service.append_event(session, event) + + final = await service.get_session( + app_name='my_app', user_id='user', session_id='s1' + ) + assert final.state.get('app:a') == 1 + assert final.state.get('user:b') == 2 + assert [e.invocation_id for e in final.events] == ['inv1'] + finally: + await service.close() + + +@pytest.mark.asyncio +async def test_merge_state_no_events_row_created(): + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + await service.create_session( + app_name='my_app', user_id='user', session_id='s1' + ) + + await service.merge_state( + app_name='my_app', + user_id='user', + session_id='s1', + delta={'app:a': 1, 'user:b': 2, 'sk': 3}, + ) + + schema = service._get_schema_classes() + async with service.database_session_factory() as sql_session: + result = await sql_session.execute(select(schema.StorageEvent)) + assert result.scalars().all() == [] + finally: + await service.close() + + +@pytest.mark.asyncio +async def test_merge_state_concurrent_independent_keys_no_lost_update(): + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + await service.create_session( + app_name='my_app', user_id='user', session_id='s1' + ) + + iteration_count = 8 + for i in range(iteration_count): + await asyncio.gather( + service.merge_state( + app_name='my_app', + user_id='user', + session_id='s1', + delta={f'sk{i}-1': f'v{i}-1'}, + ), + service.merge_state( + app_name='my_app', + user_id='user', + session_id='s1', + delta={f'sk{i}-2': f'v{i}-2'}, + ), + ) + + final = await service.get_session( + app_name='my_app', user_id='user', session_id='s1' + ) + # Both independent keys from every iteration must be present (no lost + # update), and no event row was created. + for i in range(iteration_count): + assert final.state.get(f'sk{i}-1') == f'v{i}-1' + assert final.state.get(f'sk{i}-2') == f'v{i}-2' + assert final.events == [] + finally: + await service.close() + + @pytest.mark.asyncio async def test_append_event_raises_if_app_state_row_missing(): service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 7f3cb61a05..1cbdbd77d2 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -656,6 +656,18 @@ async def test_initialize_with_project_location_and_api_key_error(): ) +@pytest.mark.asyncio +async def test_merge_state_not_implemented(): + session_service = mock_vertex_ai_session_service() + with pytest.raises(NotImplementedError, match='merge_state'): + await session_service.merge_state( + app_name='123', + user_id='user', + session_id='1', + delta={'k': 'v'}, + ) + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_get_session_returns_none_when_invalid_argument(