basis/components/cluster/__init__.py
apeters 1d204f26b8 pre-Korves.Net
Signed-off-by: apeters <apeters@korves.net>
2025-05-21 08:05:07 +00:00

625 lines
25 KiB
Python

import asyncio
import base64
import json
import os
import random
import re
import socket
import sys
import zlib
from config import defaults
from contextlib import closing, suppress
from components.cluster.exceptions import (
IncompleteClusterResponses,
DistLockCancelled,
UnknownPeer,
ZombiePeer,
)
from components.cluster.cli import cli_processor
from components.models.cluster import (
CritErrors,
Role,
FilePath,
validate_call,
ConnectionStatus,
)
from components.cluster.leader import elect_leader
from components.cluster.peers import Peers
from components.cluster.monitor import Monitor
from components.cluster.ssl import get_ssl_context
from components.logs import logger
from components.database import *
from components.utils.cryptography import dict_digest_sha1
from components.utils.datetimes import ntime_utc_now
from components.utils import ensure_list, is_path_within_cwd, chunk_string
class Cluster:
def __init__(self, peers, port):
    """Initialize cluster state: per-table locks, ticket bookkeeping,
    the peer registry and the connection monitor."""
    # Per-table asyncio locks, keyed by table name.
    self.locks = {}
    self.port = port
    # Condition used to wake await_receivers() when responses arrive.
    self.receiving = asyncio.Condition()
    # Collected peer responses, keyed by ticket id.
    self.tickets = {}
    # Chunked (PACK) payload fragments awaiting reassembly, per ticket.
    self._partial_tickets = {}
    self.monitor = Monitor(self)
    # Upper bound for a single incoming stream frame: 100 MiB.
    self.server_limit = 104857600
    self.peers = Peers(peers)
    self.sending = asyncio.Lock()
    # Tables patched under a ticket but not yet committed.
    self._session_patched_tables = {}
def _release_tables(self, tables):
for t in tables:
with suppress(RuntimeError):
self.locks[t]["lock"].release()
self.locks[t]["ticket"] = None
async def read_command(self, reader: asyncio.StreamReader) -> tuple[str, str, dict]:
    """Read one length-prefixed frame and split it into its parts.

    A frame is a 4-byte big-endian length followed by
    ``<ticket> <command> :META NAME .. SWARM .. STARTED .. LEADER ..``.
    The sender's peer record is refreshed from the META section.

    Returns (ticket, command, meta dict); raises UnknownPeer for an
    unregistered sender name.
    """
    frame_len = int.from_bytes(await reader.readexactly(4), "big")
    raw = await reader.readexactly(frame_len)
    decoded = raw.strip().decode("utf-8")
    data, _, meta = decoded.partition(" :META ")
    ticket, _, cmd = data.partition(" ")
    meta_re = (
        r"NAME (?P<name>\S+) "
        r"SWARM (?P<swarm>\S+) "
        r"STARTED (?P<started>\S+) "
        r"LEADER (?P<leader>\S+)"
    )
    meta_dict = re.search(meta_re, meta).groupdict()
    name = meta_dict["name"]
    if name not in self.peers.remotes:
        raise UnknownPeer(name)
    remote = self.peers.remotes[name]
    remote.leader = meta_dict["leader"]
    remote.started = float(meta_dict["started"])
    remote.swarm = meta_dict["swarm"]
    # Truncate long commands so the debug log stays readable.
    msg = cmd[:150] + (cmd[150:] and "...")
    logger.debug(f"[← Receiving from {name}][{ticket}] - {msg}")
    return ticket, cmd, meta_dict
async def incoming_handler(
    self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
):
    """Serve a single inbound connection.

    Connections arriving on a CLI binding are delegated to
    cli_processor; anything else must originate from a known peer IP.
    The handler then loops, reading length-prefixed commands and
    dispatching them (table sync, distributed locks, file transfer,
    chunk reassembly, acknowledgements) until the stream closes.

    Raises:
        UnknownPeer: remote address is not a configured peer.
    """
    monitor_init = False
    raddr, *_ = writer.get_extra_info("peername")
    # NOTE(review): this local deliberately shadows the imported
    # `socket` module for the rest of this handler.
    socket, *_ = writer.get_extra_info("socket").getsockname()
    if socket and raddr in defaults.CLUSTER_CLI_BINDINGS:
        # Local CLI session, not a cluster peer.
        return await cli_processor((reader, writer))
    if raddr not in self.peers.remote_ips():
        raise UnknownPeer(raddr)
    while True:
        try:
            ticket, cmd, peer_meta = await self.read_command(reader)
            if not monitor_init:
                # The first frame reveals the peer's name: register the
                # inbound streams and start monitoring exactly once.
                self.peers.remotes[peer_meta["name"]].streams._in = (
                    reader,
                    writer,
                )
                await self.monitor.start_peer_monitoring(peer_meta["name"])
                monitor_init = True
        except (asyncio.exceptions.IncompleteReadError, ConnectionResetError):
            # Peer disconnected; end the session quietly.
            break
        except TimeoutError as e:
            # Only the SSL-shutdown timeout is a normal session end.
            if str(e) == "SSL shutdown timed out":
                break
            raise
        except ZombiePeer:
            # Peer still has a stale session open: reject and drop.
            await self.send_command(
                CritErrors.ZOMBIE.response, peer_meta["name"], ticket=ticket
            )
            break
        try:
            if not self.peers.local.leader and not any(
                map(
                    lambda s: cmd.startswith(s),
                    ["ACK", "STATUS", "FULLTABLE", "INIT"],
                )
            ):
                # Without an elected leader only a small command subset
                # is served; everything else is refused.
                await self.send_command(
                    CritErrors.NOT_READY.response, peer_meta["name"], ticket=ticket
                )
            elif cmd.startswith("PATCHTABLE") or cmd.startswith("FULLTABLE"):
                # Payload layout: "<table>@<digest> <base64 data>".
                _, _, payload = cmd.partition(" ")
                table_w_hash, table_payload = payload.split(" ")
                table, table_digest = table_w_hash.split("@")
                db_params = evaluate_db_params(ticket)
                async with TinyDB(**db_params) as db:
                    if not table in db.tables():
                        await self.send_command(
                            CritErrors.NO_SUCH_TABLE.response,
                            peer_meta["name"],
                            ticket=ticket,
                        )
                    else:
                        # Remember which tables this ticket touched so a
                        # later COMMIT knows what to persist.
                        if not ticket in self._session_patched_tables:
                            self._session_patched_tables[ticket] = set()
                        self._session_patched_tables[ticket].add(table)
                        try:
                            if cmd.startswith("PATCHTABLE"):
                                table_data = {
                                    doc.doc_id: doc for doc in db.table(table).all()
                                }
                                errors = []
                                # The sender's digest must match our
                                # current state before a diff applies.
                                local_table_digest = dict_digest_sha1(table_data)
                                if local_table_digest != table_digest:
                                    await self.send_command(
                                        CritErrors.TABLE_HASH_MISMATCH.response,
                                        peer_meta["name"],
                                        ticket=ticket,
                                    )
                                    continue
                                diff = json.loads(base64.b64decode(table_payload))
                                for doc_id, docs in diff["changed"].items():
                                    # docs = (before-image, after-image)
                                    a, b = docs
                                    c = db.table(table).get(doc_id=doc_id)
                                    if c != a:
                                        # Local doc diverged from the
                                        # diff's before-image: abort.
                                        await self.send_command(
                                            CritErrors.DOC_MISMATCH.response,
                                            peer_meta["name"],
                                            ticket=ticket,
                                        )
                                        break
                                    db.table(table).upsert(
                                        Document(b, doc_id=doc_id)
                                    )
                                else:  # if no break occured, continue
                                    for doc_id, doc in diff["added"].items():
                                        db.table(table).insert(
                                            Document(doc, doc_id=doc_id)
                                        )
                                    db.table(table).remove(
                                        Query().id.one_of(
                                            [
                                                doc["id"]
                                                for doc in diff["removed"].values()
                                            ]
                                        )
                                    )
                            elif cmd.startswith("FULLTABLE"):
                                # Full snapshot: replace the table wholesale.
                                insert_data = json.loads(
                                    base64.b64decode(table_payload)
                                )
                                db.table(table).truncate()
                                for doc_id, doc in insert_data.items():
                                    db.table(table).insert(
                                        Document(doc, doc_id=doc_id)
                                    )
                            await self.send_command(
                                "ACK", peer_meta["name"], ticket=ticket
                            )
                        except Exception as e:
                            await self.send_command(
                                CritErrors.CANNOT_APPLY.response,
                                peer_meta["name"],
                                ticket=ticket,
                            )
                            continue
            elif cmd == "COMMIT":
                if not ticket in self._session_patched_tables:
                    await self.send_command(
                        CritErrors.NOTHING_TO_COMMIT.response,
                        peer_meta["name"],
                        ticket=ticket,
                    )
                    continue
                commit_tables = self._session_patched_tables[ticket]
                del self._session_patched_tables[ticket]
                # Persist every table patched under this ticket.
                await dbcommit(commit_tables, ticket)
                await self.send_command("ACK", peer_meta["name"], ticket=ticket)
            elif cmd == "STATUS" or cmd == "INIT":
                # Liveness / handshake probes are simply acknowledged.
                await self.send_command("ACK", peer_meta["name"], ticket=ticket)
            elif cmd == "BYE":
                # Peer announced shutdown: reset its state object.
                self.peers.remotes[peer_meta["name"]] = self.peers.remotes[
                    peer_meta["name"]
                ].reset()
            elif cmd.startswith("FILEGET"):
                # Payload: "<start> <end> <path>" — reply with a
                # zlib-compressed, base64-encoded, chunked file slice.
                _, _, payload = cmd.partition(" ")
                start, end, file = payload.split(" ")
                if not is_path_within_cwd(file) or not os.path.exists(file):
                    await self.send_command(
                        CritErrors.INVALID_FILE_PATH.response,
                        peer_meta["name"],
                        ticket=ticket,
                    )
                    continue
                if os.stat(file).st_size < int(start):
                    await self.send_command(
                        CritErrors.START_BEHIND_FILE_END.response,
                        peer_meta["name"],
                        ticket=ticket,
                    )
                    continue
                with open(file, "rb") as f:
                    f.seek(int(start))
                    # NOTE(review): `end` is used as a read *length*,
                    # not an absolute offset; -1 reads to EOF.
                    compressed_data = zlib.compress(f.read(int(end)))
                    compressed_data_encoded = base64.b64encode(
                        compressed_data
                    ).decode("utf-8")
                    chunks = chunk_string(f"{file} {compressed_data_encoded}")
                    for idx, c in enumerate(chunks, 1):
                        await self.send_command(
                            f"PACK CHUNKED {idx} {len(chunks)} {c}",
                            peer_meta["name"],
                            ticket=ticket,
                        )
            elif cmd.startswith("UNLOCK"):
                _, _, tables_str = cmd.partition(" ")
                tables = tables_str.split(",")
                # Mirrors _release_tables(): a not-held lock raises
                # RuntimeError, which is deliberately ignored.
                for t in tables:
                    with suppress(RuntimeError):
                        self.locks[t]["lock"].release()
                        self.locks[t]["ticket"] = None
                await self.send_command("ACK", peer_meta["name"], ticket=ticket)
            elif cmd.startswith("LOCK"):
                _, _, tables_str = cmd.partition(" ")
                tables = tables_str.split(",")
                if peer_meta["swarm"] != self.peers.local.swarm:
                    # Peer sees a different member set; refuse locking.
                    await self.send_command(
                        CritErrors.PEERS_MISMATCH.response,
                        peer_meta["name"],
                        ticket=ticket,
                    )
                    continue
                try:
                    for t in tables:
                        if t not in self.locks:
                            self.locks[t] = {
                                "lock": asyncio.Lock(),
                                "ticket": None,
                            }
                        # Short, jittered wait keeps concurrent lockers
                        # from starving each other deterministically.
                        await asyncio.wait_for(
                            self.locks[t]["lock"].acquire(),
                            0.3 + random.uniform(0.1, 0.5),
                        )
                        self.locks[t]["ticket"] = ticket
                except TimeoutError:
                    await self.send_command(
                        "ACK BUSY", peer_meta["name"], ticket=ticket
                    )
                else:
                    await self.send_command("ACK", peer_meta["name"], ticket=ticket)
            elif cmd.startswith("PACK"):
                # Reassemble a chunked payload; once the final chunk
                # arrives, record the joined payload for the ticket.
                if not ticket in self._partial_tickets:
                    self._partial_tickets[ticket] = []
                _, _, payload = cmd.partition(" ")
                _, idx, total, partial_data = payload.split(" ", 3)
                self._partial_tickets[ticket].append(partial_data)
                if idx == total:  # string comparison of chunk counters
                    self.tickets[ticket].add(
                        (peer_meta["name"], "".join(self._partial_tickets[ticket]))
                    )
            elif cmd.startswith("ACK"):
                _, _, payload = cmd.partition(" ")
                if ticket in self.tickets:
                    self.tickets[ticket].add((peer_meta["name"], payload))
            # Wake any await_receivers() coroutine; error branches above
            # use `continue` to skip this wake-up.
            async with self.receiving:
                self.receiving.notify_all()
        except ConnectionResetError:
            break
async def send_command(self, cmd, peers, ticket: str | None = None):
    """Send *cmd* to one, several or all ("*") peers.

    A fresh ticket is generated unless the caller supplies one; an empty
    response set is registered for it. Returns (ticket, list of peers
    the frame was actually delivered to).
    """
    ticket = ticket or str(ntime_utc_now())
    targets = self.peers.remotes.keys() if peers == "*" else ensure_list(peers)
    self.tickets.setdefault(ticket, set())
    delivered = []
    for name in targets:
        remote = self.peers.remotes[name]
        async with remote.sending_lock:
            con, status = await remote.connect()
            if not con:
                logger.warning(f"Cannot send to peer {name} - {status}")
                continue
            _, writer = con
            swarm = self.peers.local.swarm or "?CONFUSED"
            leader = self.peers.local.leader or "?CONFUSED"
            # Frame: ticket, command, then our own META section.
            frame = " ".join(
                [
                    ticket,
                    cmd,
                    ":META",
                    f"NAME {self.peers.local.name}",
                    f"SWARM {swarm}",
                    f"STARTED {self.peers.local.started}",
                    f"LEADER {leader}",
                ]
            ).encode("utf-8")
            # Length prefix (4 bytes, big-endian) followed by payload.
            writer.write(len(frame).to_bytes(4, "big"))
            writer.write(frame)
            await writer.drain()
            msg = cmd[:150] + (cmd[150:] and "...")
            logger.debug(f"[→ Sending to {name}][{ticket}] - {msg}")
            delivered.append(name)
    return ticket, delivered
async def release(self, tables: list | None = None) -> None:
    """Release the local locks for *tables* (default: ["main"]).

    Resets the ticket context var installed by acquire_lock(). As a
    follower, additionally asks the leader to drop its matching locks
    before releasing locally.

    Raises:
        DistLockCancelled: the leader did not acknowledge the unlock.
    """
    # Fix: the original used a shared mutable default (tables=["main"]).
    if tables is None:
        tables = ["main"]
    CTX_TICKET.reset(self._ctx_ticket)
    if self.peers.local.role == Role.FOLLOWER:
        try:
            ticket, receivers = await self.send_command(
                f"UNLOCK {','.join(tables)}",
                self.peers.local.leader,
            )
            async with self.receiving:
                # raise_err=True turns missing/critical responses into
                # IncompleteClusterResponses below.
                await self.await_receivers(ticket, receivers, raise_err=True)
        except IncompleteClusterResponses:
            # Still drop the local locks before surfacing the failure.
            self._release_tables(tables)
            raise DistLockCancelled(
                "Leader did not respond properly to unlock request"
            )
    self._release_tables(tables)
async def acquire_lock(self, tables: list | None = None) -> None:
    """Acquire local locks for *tables* (default: ["main"]) and, as a
    follower, the matching distributed locks from the leader.

    On a BUSY reply the local locks are dropped and the acquisition is
    retried recursively. On success the ticket is stored in CTX_TICKET
    (reset again by release()).

    Raises:
        DistLockCancelled: local timeout, leader rejection or
            cancellation while negotiating.
        IncompleteClusterResponses: no leader is known yet.
    """
    # Fix: the original used a shared mutable default (tables=["main"]).
    if tables is None:
        tables = ["main"]
    try:
        ticket = str(ntime_utc_now())
        for t in tables:
            if t not in self.locks:
                self.locks[t] = {
                    "lock": asyncio.Lock(),
                    "ticket": None,
                }
            await asyncio.wait_for(self.locks[t]["lock"].acquire(), 20.0)
            self.locks[t]["ticket"] = ticket
    except TimeoutError:
        raise DistLockCancelled("Unable to acquire local lock")
    if self.peers.local.role == Role.FOLLOWER:
        try:
            if not self.peers.local.leader:
                self._release_tables(tables)
                raise IncompleteClusterResponses("Cluster is not ready")
            ticket, receivers = await self.send_command(
                f"LOCK {','.join(tables)}", self.peers.local.leader
            )
            async with self.receiving:
                result, responses = await self.await_receivers(
                    ticket, receivers, raise_err=False
                )
            # Evaluated outside the condition: a recursive retry must
            # not re-enter `self.receiving` while it is still held.
            if "BUSY" in responses:
                self._release_tables(tables)
                return await self.acquire_lock(tables)
            elif not result:
                self._release_tables(tables)
                if CritErrors.PEERS_MISMATCH in responses:
                    raise DistLockCancelled(
                        "Leader rejected lock due to member inconsistency"
                    )
                else:
                    raise DistLockCancelled(
                        "Cannot acquire lock from leader, trying a re-election"
                    )
        except asyncio.CancelledError:
            self._release_tables(tables)
            raise DistLockCancelled("Application raised CancelledError")
    self._ctx_ticket = CTX_TICKET.set(ticket)
@validate_call
async def request_files(
    self,
    files: FilePath | list[FilePath],
    peers: str | list,
    start: int = 0,
    end: int = -1,
):
    """Fetch *files* from *peers* via FILEGET and store each copy under
    ``peer_files/<peer>/<file>``.

    ``start``/``end`` select the byte range (end is forwarded to the
    remote read; -1 means to EOF). A start of -1 resumes from the size
    of the locally existing copy, if any. Callers must hold the "files"
    lock before invoking this method.
    """
    # The distributed "files" lock must already be held by the caller.
    assert self.locks["files"]["lock"].locked()
    for file in ensure_list(files):
        try:
            if not is_path_within_cwd(file):
                logger.error(f"File not within cwd: {file}")
                continue
            for peer in ensure_list(peers):
                # Per-peer copies of the requested range; peer_start may
                # be rewritten for resume mode below.
                peer_start = start
                peer_end = end
                assert peer in self.peers.remotes
                if peer_start == -1:
                    # Resume: continue after the bytes we already have.
                    if os.path.exists(f"peer_files/{peer}/{file}"):
                        peer_start = os.stat(f"peer_files/{peer}/{file}").st_size
                    else:
                        peer_start = 0
                ticket, receivers = await self.send_command(
                    f"FILEGET {peer_start} {peer_end} {file}", peer
                )
                async with self.receiving:
                    _, response = await self.await_receivers(
                        ticket, receivers, raise_err=True
                    )
                for r in response:
                    # Response payload: "<path> <base64(zlib data)>".
                    r_file, r_data = r.split(" ")
                    assert FilePath(r_file) == file
                    file_dest = f"peer_files/{peer}/{file}"
                    os.makedirs(os.path.dirname(file_dest), exist_ok=True)
                    payload = zlib.decompress(base64.b64decode(r_data))
                    # r+b keeps existing bytes outside the written range
                    # intact; w+b creates the file fresh.
                    if os.path.exists(file_dest):
                        mode = "r+b"
                    else:
                        mode = "w+b"
                    with open(file_dest, mode) as f:
                        f.seek(peer_start)
                        f.write(payload)
        except IncompleteClusterResponses:
            logger.error(f"Sending command to peers '{peers}' failed")
            raise
        except Exception as e:
            logger.error(f"Unhandled error: {e}")
            raise
async def run(self, shutdown_trigger) -> None:
    """Start the cluster server and run until *shutdown_trigger* fires.

    Binds the TLS server, sends INIT to every reachable peer (aborting
    on a ZOMBIE response), starts the ticket-cleanup worker and then
    waits for shutdown. On cancellation, says BYE to established peers
    and cancels the monitor tasks.
    """
    server = await asyncio.start_server(
        self.incoming_handler,
        self.peers.local._all_bindings_as_str,
        self.port,
        ssl=get_ssl_context("server"),
        limit=self.server_limit,
    )
    self._shutdown = shutdown_trigger
    logger.info(
        f"Listening on {self.port} on address {' and '.join(self.peers.local._all_bindings_as_str)}..."
    )
    async with server:
        # Handshake with every configured peer that is resolvable.
        for name in self.peers.remotes:
            ip, status = self.peers.remotes[name]._eval_ip()
            if ip:
                con, status = await self.peers.remotes[name].connect(ip)
                if con:
                    ticket, receivers = await self.send_command("INIT", name)
                    async with self.receiving:
                        ret, responses = await self.await_receivers(
                            ticket, receivers, raise_err=False
                        )
                    if CritErrors.ZOMBIE in responses:
                        # Peer still holds our previous session: refuse
                        # to start and trigger shutdown.
                        logger.critical(
                            f"Peer {name} has not yet disconnected a previous session: {status}"
                        )
                        shutdown_trigger.set()
                        break
                else:
                    logger.debug(f"Not sending INIT to peer {name}: {status}")
            else:
                logger.debug(f"Not sending INIT to peer {name}: {status}")
        # Background worker that expires stale tickets; discarded from
        # the task set automatically when it finishes.
        t = asyncio.create_task(self.monitor._ticket_worker(), name="tickets")
        self.monitor.tasks.add(t)
        t.add_done_callback(self.monitor.tasks.discard)
        try:
            # Verify that every requested binding was actually bound.
            binds = [s.getsockname()[0] for s in server.sockets]
            for local_rbind in self.peers.local._bindings_as_str:
                if local_rbind not in binds:
                    logger.critical(
                        f"Could not bind requested address {local_rbind}"
                    )
                    shutdown_trigger.set()
                    break
            else:
                # All bindings fine: block until shutdown is requested.
                await shutdown_trigger.wait()
        except asyncio.CancelledError:
            # Graceful teardown: notify peers, stop monitor tasks.
            if self.peers.get_established():
                await self.send_command("BYE", "*")
            [t.cancel() for t in self.monitor.tasks]
async def await_receivers(
    self,
    ticket,
    receivers,
    raise_err: bool,
    timeout: float = defaults.CLUSTER_PEERS_TIMEOUT * len(defaults.CLUSTER_PEERS),
):
    """Wait until every peer in *receivers* answered *ticket*.

    Must be called while holding ``self.receiving``. Collects the
    responses recorded by incoming_handler(), converts known critical
    error payloads to CritErrors members and removes the ticket.

    Returns:
        (ok, responses) where ok is True iff no timeout, no critical
        response and a plausible response count.

    Raises:
        ValueError: *timeout* exceeds the allowed maximum.
        IncompleteClusterResponses: raise_err is True and errors occurred.
    """
    errors = []
    missing_receivers = []
    assert self.receiving.locked()
    if timeout >= defaults.CLUSTER_PEERS_TIMEOUT * len(defaults.CLUSTER_PEERS) + 10:
        raise ValueError("Timeout is too high")
    try:
        while not all(
            r in [peer for peer, _ in self.tickets[ticket]] for r in receivers
        ):
            await asyncio.wait_for(self.receiving.wait(), timeout)
    except TimeoutError:
        missing_receivers = [
            r
            for r in receivers
            if r not in [peer for peer, _ in self.tickets[ticket]]
        ]
        # Fix: message typo "receviers" -> "receivers".
        errors.append(
            f"Timeout waiting for receivers: {', '.join(missing_receivers)}"
        )
    except asyncio.CancelledError:
        # Fix: the original returned from inside `finally`, which
        # swallowed cancellation (callers explicitly handle it). Clean
        # up the ticket, then let the cancellation propagate.
        self.tickets.pop(ticket, None)
        raise
    responses = []
    for peer, response in self.tickets[ticket]:
        if response in CritErrors._value2member_map_:
            crit = CritErrors(response)
            responses.append(crit)
            errors.append(f"CRIT response from {peer}: {crit}")
        else:
            responses.append(response)
    if not missing_receivers and len(responses) != len(receivers):
        errors.append("Implausible amount of responses for ticket")
    self.tickets.pop(ticket, None)
    if errors:
        logger.error("\n".join(errors))
        if raise_err:
            raise IncompleteClusterResponses("\n".join(errors))
    else:
        logger.success(f"{ticket} OK")
    return not errors, responses