Collaborative session implementation, 1
This commit is contained in:
94
synctoy/data_model.py
Normal file
94
synctoy/data_model.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Literal, NewType
|
||||
|
||||
from pydantic import BaseModel, JsonValue
|
||||
|
||||
from synctoy.state_machine import StateMachine
|
||||
|
||||
|
||||
class Type(StrEnum):
|
||||
Null = "null"
|
||||
User = "user"
|
||||
Message = "message"
|
||||
|
||||
ObjectId = NewType("ObjectId", str)
|
||||
|
||||
class RNull(BaseModel):
|
||||
type: Literal[Type.Null] = Type.Null
|
||||
id: ObjectId
|
||||
|
||||
def equivalent_to(self, other: StateMachine):
|
||||
return self == other
|
||||
|
||||
def can_transition_from(self, old: StateMachine):
|
||||
return True
|
||||
|
||||
def can_transition_to(self, new: StateMachine):
|
||||
return True
|
||||
|
||||
class RUser(BaseModel):
|
||||
type: Literal[Type.User] = Type.User
|
||||
id: ObjectId
|
||||
name: str
|
||||
|
||||
def equivalent_to(self, other: StateMachine):
|
||||
return self == other
|
||||
|
||||
def can_transition_from(self, old: StateMachine):
|
||||
if isinstance(old, RNull):
|
||||
return True
|
||||
return False
|
||||
|
||||
def can_transition_to(self, new: StateMachine):
|
||||
return False
|
||||
|
||||
class RMessage(BaseModel):
|
||||
type: Literal[Type.Message] = Type.Message
|
||||
id: ObjectId
|
||||
sender: ObjectId
|
||||
content: str
|
||||
|
||||
def equivalent_to(self, other: StateMachine):
|
||||
return self == other
|
||||
|
||||
def can_transition_from(self, old: StateMachine):
|
||||
if isinstance(old, RNull):
|
||||
return True
|
||||
if isinstance(old, RMessage):
|
||||
return True
|
||||
return False
|
||||
|
||||
def can_transition_to(self, new: StateMachine):
|
||||
if isinstance(new, RMessage):
|
||||
return new.sender == self.sender
|
||||
return True
|
||||
|
||||
NonNullRecord = RUser | RMessage
|
||||
Record = RNull | NonNullRecord
|
||||
|
||||
|
||||
|
||||
def as_state_machine(r: Record) -> StateMachine:
|
||||
return r
|
||||
|
||||
|
||||
class ConditionType(StrEnum):
|
||||
True_ = "true"
|
||||
EquivalentTo = "equivalent_to"
|
||||
|
||||
|
||||
class CTrue(BaseModel):
|
||||
type: Literal[ConditionType.True_] = ConditionType.True_
|
||||
|
||||
def is_met(self, old: Record, new: Record):
|
||||
return True
|
||||
|
||||
class CEquivalentTo(BaseModel):
|
||||
type: Literal[ConditionType.EquivalentTo] = ConditionType.EquivalentTo
|
||||
|
||||
expected: Record
|
||||
|
||||
def is_met(self, old: Record, new: Record):
|
||||
return old.equivalent_to(self.expected)
|
||||
|
||||
Condition = CTrue | CEquivalentTo
|
||||
55
synctoy/main.py
Normal file
55
synctoy/main.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import asyncio
|
||||
import json
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.trustedhost import TrustedHostMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
from starlette.routing import Route
|
||||
|
||||
from synctoy.data_model import ObjectId, NonNullRecord, Record
|
||||
from synctoy.snapshot import SAll
|
||||
from synctoy.store import Store, Update
|
||||
|
||||
|
||||
store = Store()
|
||||
|
||||
_TIMEOUT = 5.0
|
||||
|
||||
async def snapshot(request: Request):
|
||||
snapshot = store.snapshot(SAll())
|
||||
populated = snapshot.populated_records()
|
||||
return JSONResponse(TypeAdapter(dict[ObjectId, NonNullRecord]).dump_python(populated, mode="json"))
|
||||
|
||||
class UpdateResult(BaseModel):
|
||||
index: int
|
||||
|
||||
async def update(request: Request):
|
||||
value = Update.model_validate_json(await request.body())
|
||||
version = store.update([value])
|
||||
return JSONResponse(UpdateResult(index=version).model_dump(mode="json"))
|
||||
|
||||
async def events(request: Request):
|
||||
start = request.query_params.get("start")
|
||||
if start is None:
|
||||
start = 0
|
||||
else:
|
||||
start = int(start)
|
||||
|
||||
result = store.observe(start_from=start)
|
||||
if len(result) == 0:
|
||||
await store.wait(_TIMEOUT)
|
||||
result = store.observe(start_from=start)
|
||||
|
||||
await asyncio.sleep(1.0) # simulated delay! wow
|
||||
return JSONResponse(TypeAdapter(list[Record]).dump_python(result))
|
||||
|
||||
app = Starlette(debug=True, routes=[
|
||||
Route("/snapshot", snapshot, methods=["GET"]),
|
||||
Route("/update", update, methods=["POST"]),
|
||||
Route("/events", events, methods=["GET"])
|
||||
],
|
||||
)
|
||||
app = CORSMiddleware(app=app, allow_origins=["*"])
|
||||
43
synctoy/snapshot.py
Normal file
43
synctoy/snapshot.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Protocol
|
||||
|
||||
from synctoy.data_model import NonNullRecord, ObjectId, RNull, Record
|
||||
|
||||
|
||||
class Selector(Protocol):
|
||||
def includes_object_id(self, object_id: ObjectId):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SAll(object):
|
||||
def includes_object_id(self, object_id: ObjectId):
|
||||
return True
|
||||
|
||||
class SList(object):
|
||||
def __init__(self, object_ids: list[ObjectId]):
|
||||
self._object_ids = object_ids
|
||||
|
||||
def includes_object_id(self, object_id: ObjectId):
|
||||
return object_id in self._object_ids
|
||||
|
||||
class Snapshot(object):
|
||||
def __init__(self, selector: Selector, data: dict[ObjectId, NonNullRecord]):
|
||||
self._selector = selector
|
||||
self._data = data
|
||||
|
||||
def __getitem__(self, object_id: ObjectId):
|
||||
if not self._selector.includes_object_id(object_id):
|
||||
raise KeyError(f"{object_id} was not selected and cannot be viewed")
|
||||
|
||||
return self._data.get(object_id) or RNull(id=object_id)
|
||||
|
||||
def __setitem__(self, object_id: ObjectId, value: Record):
|
||||
if not self._selector.includes_object_id(object_id):
|
||||
raise KeyError(f"{object_id} was not selected and cannot be staged")
|
||||
|
||||
if isinstance(value, RNull):
|
||||
self._data.pop(object_id, None)
|
||||
else:
|
||||
self._data[object_id] = value
|
||||
|
||||
def populated_records(self) -> dict[ObjectId, NonNullRecord]:
|
||||
return dict(self._data)
|
||||
12
synctoy/state_machine.py
Normal file
12
synctoy/state_machine.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class StateMachine(Protocol):
|
||||
def equivalent_to(self, other: "StateMachine") -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def can_transition_from(self, old: "StateMachine") -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def can_transition_to(self, new: "StateMachine") -> bool:
|
||||
raise NotImplementedError
|
||||
65
synctoy/store.py
Normal file
65
synctoy/store.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import asyncio
|
||||
from typing import NewType, Protocol
|
||||
from pydantic import BaseModel, JsonValue
|
||||
|
||||
from synctoy.data_model import Condition, NonNullRecord, ObjectId, RNull, Record
|
||||
from synctoy.snapshot import SList, Selector, Snapshot
|
||||
|
||||
class Update(BaseModel):
|
||||
condition: Condition
|
||||
new: Record
|
||||
|
||||
class PreconditionException(Exception):
|
||||
pass
|
||||
|
||||
class Store(object):
|
||||
def __init__(self):
|
||||
self._event = asyncio.Event()
|
||||
self._versions: list[Record] = []
|
||||
|
||||
async def wait(self, max_timeout: float):
|
||||
try:
|
||||
await asyncio.wait_for(self._event.wait(), max_timeout)
|
||||
except TimeoutError:
|
||||
return
|
||||
|
||||
def observe(self, start_from: int):
|
||||
return self._versions[start_from:]
|
||||
|
||||
def update(self, updates: list[Update]) -> int:
|
||||
object_ids: set[ObjectId] = set(u.new.id for u in updates)
|
||||
snapshot: Snapshot = self.snapshot(SList(object_ids=list(object_ids)))
|
||||
|
||||
for update in updates:
|
||||
id = update.new.id
|
||||
old_object = snapshot[id]
|
||||
new_object = update.new
|
||||
|
||||
if not old_object.can_transition_to(new_object):
|
||||
raise PreconditionException(f"can't transition from {old_object} to {new_object}")
|
||||
if not new_object.can_transition_from(old_object):
|
||||
raise PreconditionException(f"can't transition from {old_object} to {new_object}")
|
||||
|
||||
if not update.condition.is_met(old_object, new_object):
|
||||
raise PreconditionException(f"failed condition: {update.condition}")
|
||||
|
||||
snapshot[id] = new_object
|
||||
|
||||
for new_update in updates:
|
||||
self._versions.append(new_update.new)
|
||||
self._event.set()
|
||||
self._event = asyncio.Event()
|
||||
return len(self._versions)
|
||||
|
||||
def snapshot(self, selector: Selector) -> Snapshot:
|
||||
pre_snapshot: dict[ObjectId, Record] = {}
|
||||
for row in self._versions:
|
||||
if not selector.includes_object_id(row.id):
|
||||
continue
|
||||
|
||||
pre_snapshot[row.id] = row
|
||||
filtered: dict[ObjectId, NonNullRecord] = {
|
||||
id: row for id, row in pre_snapshot.items() if not isinstance(row, RNull)
|
||||
}
|
||||
|
||||
return Snapshot(selector, filtered)
|
||||
Reference in New Issue
Block a user