Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_set_state_for_new_state(self):
    """A state set for the first time is tracked as an ``add`` change."""
    name, expected_value = 'state1', 'value1'
    manager = ActorStateManager(self._fake_actor)
    _run(manager.set_state(name, expected_value))

    tracked = manager._state_change_tracker[name]
    self.assertEqual(StateChangeKind.add, tracked.change_kind)
    self.assertEqual(expected_value, tracked.value)
def test_add_state(self):
    """try_add_state adds an unknown state once and rejects a duplicate add."""
    state_manager = ActorStateManager(self._fake_actor)

    # Add first 'state1'
    added = _run(state_manager.try_add_state('state1', 'value1'))
    self.assertTrue(added)
    state = state_manager._state_change_tracker['state1']
    self.assertEqual('value1', state.value)
    self.assertEqual(StateChangeKind.add, state.change_kind)

    # Add 'state1' again with a *different* value: the add must be rejected
    # and the change already being tracked must not be overwritten.
    added = _run(state_manager.try_add_state('state1', 'value2'))
    self.assertFalse(added)
    state = state_manager._state_change_tracker['state1']
    self.assertEqual('value1', state.value)
    self.assertEqual(StateChangeKind.add, state.change_kind)
async def try_add_state(self, state_name: str, value: T) -> bool:
    """Track ``value`` under ``state_name`` only if that state does not exist yet.

    Args:
        state_name: Name of the actor state to add.
        value: Value to associate with the state.

    Returns:
        True if the state was added to the change tracker (including the case
        where a pending removal is turned into an update), False if the state
        is already tracked or already present in the state provider.
    """
    if state_name in self._state_change_tracker:
        state_metadata = self._state_change_tracker[state_name]
        # A state marked for removal has not been flushed yet, so re-adding it
        # replaces the pending removal with an update of the stored value.
        if state_metadata.change_kind == StateChangeKind.remove:
            self._state_change_tracker[state_name] = \
                StateMetadata(value, StateChangeKind.update)
            return True
        return False

    existed = await self._actor.runtime_ctx.state_provider.contains_state(
        self._type_name, self._actor.id.id, state_name)
    # Only a state that is not yet persisted may be added.
    if existed:
        return False

    self._state_change_tracker[state_name] = StateMetadata(value, StateChangeKind.add)
    return True
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
"""
import io
from typing import Any, List, Type, Tuple
from dapr.actor.runtime.state_change import StateChangeKind, ActorStateChange
from dapr.clients import DaprActorClientBase
from dapr.serializers import Serializer, DefaultJSONSerializer
# Mapping StateChangeKind to Dapr State Operation.
# Removed states are sent as a delete; both newly added and updated states
# are sent as an upsert.
_MAP_CHANGE_KIND_TO_OPERATION = {
    StateChangeKind.remove: b'delete',
    **dict.fromkeys((StateChangeKind.add, StateChangeKind.update), b'upsert'),
}
class StateProvider:
def __init__(
        self,
        actor_client: DaprActorClientBase,
        state_serializer: Serializer = None):
    """Create a state provider backed by ``actor_client``.

    Args:
        actor_client: Dapr actor client used for actor state requests.
        state_serializer: Serializer for state values. When omitted (or
            falsy), a fresh ``DefaultJSONSerializer`` is used instead.
    """
    self._state_client = actor_client
    if state_serializer:
        self._state_serializer = state_serializer
    else:
        self._state_serializer = DefaultJSONSerializer()
async def try_load_state(
self, actor_type: str, actor_id: str,
state_name: str, state_type: Type[Any] = object) -> Tuple[bool, Any]:
def is_state_marked_for_remove(self, state_name: str) -> bool:
    """Return True if ``state_name`` is tracked with a pending removal."""
    tracker = self._state_change_tracker
    if state_name not in tracker:
        return False
    return tracker[state_name].change_kind == StateChangeKind.remove
async def save_state(self) -> None:
    """Flush every modified tracked state to the actor's state provider.

    Entries whose change kind is ``StateChangeKind.none`` are skipped. All
    other entries are collected as ``ActorStateChange`` objects and sent in a
    single ``state_provider.save_state`` call. Each sent entry is reset to
    ``none``, and entries that were removals are dropped from the tracker
    afterwards.
    """
    if not self._state_change_tracker:
        return

    changes = []
    removed_names = []
    for name, metadata in self._state_change_tracker.items():
        kind = metadata.change_kind
        if kind == StateChangeKind.none:
            continue
        changes.append(ActorStateChange(name, metadata.value, kind))
        if kind == StateChangeKind.remove:
            removed_names.append(name)
        # Mark the states as unmodified so that tracking for next invocation is done correctly.
        # NOTE(review): this reset happens before the provider call below; if
        # that call raises, these changes are no longer marked as pending here.
        metadata.change_kind = StateChangeKind.none

    if changes:
        await self._actor.runtime_ctx.state_provider.save_state(
            self._type_name, self._actor.id.id, changes)

    for name in removed_names:
        self._state_change_tracker.pop(name, None)
Licensed under the MIT License.
"""
import io
from typing import Any, List, Type, Tuple
from dapr.actor.runtime.state_change import StateChangeKind, ActorStateChange
from dapr.clients import DaprActorClientBase
from dapr.serializers import Serializer, DefaultJSONSerializer
# Mapping StateChangeKind to Dapr State Operation.
# Removed states are sent as a delete; both newly added and updated states
# are sent as an upsert.
_MAP_CHANGE_KIND_TO_OPERATION = {
    StateChangeKind.remove: b'delete',
    **dict.fromkeys((StateChangeKind.add, StateChangeKind.update), b'upsert'),
}
class StateProvider:
def __init__(
        self,
        actor_client: DaprActorClientBase,
        state_serializer: Serializer = None):
    """Create a state provider backed by ``actor_client``.

    Args:
        actor_client: Dapr actor client used for actor state requests.
        state_serializer: Serializer for state values. When omitted (or
            falsy), a fresh ``DefaultJSONSerializer`` is used instead.
    """
    self._state_client = actor_client
    if state_serializer:
        self._state_serializer = state_serializer
    else:
        self._state_serializer = DefaultJSONSerializer()
async def try_load_state(
self, actor_type: str, actor_id: str,
state_name: str, state_type: Type[Any] = object) -> Tuple[bool, Any]:
raw_state_value = await self._state_client.get_state(actor_type, actor_id, state_name)
if (not raw_state_value) or len(raw_state_value) == 0: