You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
532 lines
20 KiB
Python
532 lines
20 KiB
Python
"""A base class session manager."""
|
|
|
|
# Copyright (c) Jupyter Development Team.
|
|
# Distributed under the terms of the Modified BSD License.
|
|
import os
|
|
import pathlib
|
|
import uuid
|
|
from typing import Any, NewType, Optional, Union, cast
|
|
|
|
# Distinct static type for kernel spec names; at runtime a KernelName is a plain str.
KernelName = NewType("KernelName", str)
|
|
# Distinct static type for model names (usually the filename associated with
# the kernel); at runtime a ModelName is a plain str.
ModelName = NewType("ModelName", str)
|
|
|
|
try:
|
|
import sqlite3
|
|
except ImportError:
|
|
# fallback on pysqlite2 if Python was built without sqlite
|
|
from pysqlite2 import dbapi2 as sqlite3 # type:ignore[no-redef]
|
|
|
|
from dataclasses import dataclass, fields
|
|
|
|
from jupyter_core.utils import ensure_async
|
|
from tornado import web
|
|
from traitlets import Instance, TraitError, Unicode, validate
|
|
from traitlets.config.configurable import LoggingConfigurable
|
|
|
|
from jupyter_server.traittypes import InstanceFromClasses
|
|
|
|
|
|
class KernelSessionRecordConflict(Exception):
    """Raised when two KernelSessionRecords carry conflicting data and
    therefore cannot be merged into one record.
    """
|
|
|
|
|
|
@dataclass
class KernelSessionRecord:
    """Pairing of a session id with the id of the kernel that backs it.

    A given session_id belongs to exactly one kernel_id, whereas a single
    kernel may serve several sessions (and therefore several session_ids).
    Either field may be ``None`` while the record is still incomplete.
    """

    session_id: Optional[str] = None
    kernel_id: Optional[str] = None

    def __eq__(self, other: object) -> bool:
        """Compare two records, treating a missing kernel_id as a wildcard.

        Two records are equal when they name the same (non-empty) kernel, or
        when they name the same session and at least one of them has no
        kernel_id yet. Anything that is not a KernelSessionRecord compares
        unequal.

        Raises
        ------
        KernelSessionRecordConflict
            If both records share the same non-empty session_id but carry
            different kernel_ids (neither being ``None``).
        """
        if not isinstance(other, KernelSessionRecord):
            return False

        same_session = self.session_id == other.session_id
        same_kernel = bool(self.kernel_id) and self.kernel_id == other.kernel_id
        kernel_missing = self.kernel_id is None or other.kernel_id is None
        if same_kernel or (same_session and kernel_missing):
            return True

        # One session pointing at two different kernels is an ill-posed
        # comparison that should never occur; report it loudly instead of
        # quietly answering False.
        if self.session_id and same_session and self.kernel_id != other.kernel_id:
            msg = (
                "A single session_id can only have one kernel_id "
                "associated with. These two KernelSessionRecords share the same "
                "session_id but have different kernel_ids. This should "
                "not be possible and is likely an issue with the session "
                "records."
            )
            raise KernelSessionRecordConflict(msg)
        return False

    def update(self, other: "KernelSessionRecord") -> None:
        """Merge the populated fields of ``other`` into this record in place.

        Only truthy values are copied, so an empty field on ``other`` never
        erases a value already stored here.

        Raises
        ------
        TypeError
            If ``other`` is not a KernelSessionRecord.
        KernelSessionRecordConflict
            If both records have a kernel_id and the two differ.
        """
        if not isinstance(other, KernelSessionRecord):
            msg = "'other' must be an instance of KernelSessionRecord."  # type:ignore[unreachable]
            raise TypeError(msg)

        kernels_known = other.kernel_id and self.kernel_id
        if kernels_known and other.kernel_id != self.kernel_id:
            msg = "Could not update the record from 'other' because the two records conflict."
            raise KernelSessionRecordConflict(msg)

        for spec in fields(self):
            incoming = getattr(other, spec.name, None)
            if incoming:
                setattr(self, spec.name, incoming)
|
|
|
|
|
|
class KernelSessionRecordList:
    """A collection of KernelSessionRecords with merge-on-insert semantics.

    When a record is added, the list first looks for an existing entry that
    compares equal to it (per ``KernelSessionRecord.__eq__``). If one is
    found, that entry is updated in place with the new information;
    otherwise the record is appended.
    """

    _records: list[KernelSessionRecord]

    def __init__(self, *records: KernelSessionRecord) -> None:
        """Build the list, merging each of ``records`` in order."""
        self._records = []
        for record in records:
            self.update(record)

    def __str__(self) -> str:
        """Return the string form of the underlying list of records."""
        return str(self._records)

    def __contains__(self, record: Union[KernelSessionRecord, str]) -> bool:
        """Report whether a record, session_id or kernel_id is in the list.

        A ``str`` matches any stored record whose session_id or kernel_id
        equals it; a KernelSessionRecord matches by record equality. Any
        other type is never contained.
        """
        if isinstance(record, str):
            return any(record in (r.session_id, r.kernel_id) for r in self._records)
        return isinstance(record, KernelSessionRecord) and record in self._records

    def __len__(self) -> int:
        """Return the number of stored records."""
        return len(self._records)

    def get(self, record: Union[KernelSessionRecord, str]) -> KernelSessionRecord:
        """Return the stored KernelSessionRecord matching a session_id,
        kernel_id, or (possibly incomplete) KernelSessionRecord.

        Raises
        ------
        ValueError
            If nothing in the list matches ``record``.
        """
        if isinstance(record, str):
            for r in self._records:
                if record in (r.kernel_id, r.session_id):
                    return r
        elif isinstance(record, KernelSessionRecord):
            for r in self._records:
                if record == r:
                    # Return the entry held by the list, not the caller's
                    # query object: the query may be missing fields that the
                    # stored record already knows.
                    return r
        msg = f"{record} not found in KernelSessionRecordList."
        raise ValueError(msg)

    def update(self, record: KernelSessionRecord) -> None:
        """Update a record in-place or append it if not in the list."""
        try:
            idx = self._records.index(record)
        except ValueError:
            # No stored entry compares equal to ``record``.
            self._records.append(record)
        else:
            self._records[idx].update(record)

    def remove(self, record: KernelSessionRecord) -> None:
        """Remove the first stored entry equal to ``record``.

        Does nothing when no entry matches.
        """
        if record in self._records:
            self._records.remove(record)
|
|
|
|
|
|
class SessionManager(LoggingConfigurable):
    """A session manager.

    Sessions are rows of an SQLite table named ``session`` with the columns
    ``session_id``, ``path``, ``name``, ``type`` and ``kernel_id``. Kernels
    are started, inspected and shut down through ``kernel_manager``.
    """

    database_filepath = Unicode(
        default_value=":memory:",
        help=(
            "The filesystem path to SQLite Database file "
            "(e.g. /path/to/session_database.db). By default, the session "
            "database is stored in-memory (i.e. `:memory:` setting from sqlite3) "
            "and does not persist when the current Jupyter Server shuts down."
        ),
    ).tag(config=True)

    @validate("database_filepath")
    def _validate_database_filepath(self, proposal):
        """Validate a database file path.

        ``:memory:`` and paths that do not exist yet are accepted as-is. An
        existing path must be a regular file that is either empty or starts
        with the SQLite 3 header; otherwise a ``TraitError`` is raised.
        """
        value = proposal["value"]
        if value == ":memory:":
            return value
        path = pathlib.Path(value)
        if path.exists():
            # Verify that the database path is not a directory.
            if path.is_dir():
                msg = "`database_filepath` expected a file path, but the given path is a directory."
                raise TraitError(msg)
            # Verify that database path is an SQLite 3 Database by checking its header.
            with open(value, "rb") as f:
                header = f.read(100)

            # An empty file is tolerated.
            if not header.startswith(b"SQLite format 3") and header != b"":
                msg = "The given file is not an SQLite database file."
                raise TraitError(msg)
        return value

    kernel_manager = Instance("jupyter_server.services.kernels.kernelmanager.MappingKernelManager")
    contents_manager = InstanceFromClasses(
        [
            "jupyter_server.services.contents.manager.ContentsManager",
            "notebook.services.contents.manager.ContentsManager",
        ]
    )

    def __init__(self, *args, **kwargs):
        """Initialize the session manager with an empty list of pending sessions."""
        super().__init__(*args, **kwargs)
        # Records for sessions that are currently being created or deleted.
        self._pending_sessions = KernelSessionRecordList()

    # Session database initialized below
    _cursor = None
    _connection = None
    # Whitelist of column names that may be interpolated into SQL text.
    _columns = {"session_id", "path", "name", "type", "kernel_id"}

    @property
    def cursor(self):
        """Return the shared cursor, creating it and the ``session`` table on first use."""
        if self._cursor is None:
            self._cursor = self.connection.cursor()
            self._cursor.execute(
                """CREATE TABLE IF NOT EXISTS session
                (session_id, path, name, type, kernel_id)"""
            )
        return self._cursor

    @property
    def connection(self):
        """Return the shared database connection, opening it on first use."""
        if self._connection is None:
            # Set isolation level to None to autocommit all changes to the database.
            self._connection = sqlite3.connect(self.database_filepath, isolation_level=None)
            self._connection.row_factory = sqlite3.Row
        return self._connection

    def close(self):
        """Close the sqlite cursor and connection.

        Both handles are reset to ``None`` so calling this twice is harmless,
        and a later access to ``cursor`` / ``connection`` opens fresh ones.
        Note that for the default ``:memory:`` database, closing the
        connection discards the stored sessions.
        """
        if self._cursor is not None:
            self._cursor.close()
            self._cursor = None
        # Closing only the cursor would leave the connection (and, for an
        # on-disk database, its file handle) open.
        if self._connection is not None:
            self._connection.close()
            self._connection = None

    def __del__(self):
        """Release the database handles when the SessionManager is collected."""
        self.close()

    async def session_exists(self, path):
        """Check to see if a session with the given path exists.

        Returns ``True`` only when a row is found for ``path`` and its kernel
        is still known to the kernel manager.
        """
        exists = False
        self.cursor.execute("SELECT * FROM session WHERE path=?", (path,))
        row = self.cursor.fetchone()
        if row is not None:
            # Note, although we found a row for the session, the associated kernel may have
            # been culled or died unexpectedly. If that's the case, we should delete the
            # row, thereby terminating the session. This can be done via a call to
            # row_to_model that tolerates that condition. If row_to_model returns None,
            # we'll return false, since, at that point, the session doesn't exist anyway.
            model = await self.row_to_model(row, tolerate_culled=True)
            if model is not None:
                exists = True
        return exists

    def new_session_id(self) -> str:
        """Create a uuid for a new session"""
        return str(uuid.uuid4())

    async def create_session(
        self,
        path: Optional[str] = None,
        name: Optional[ModelName] = None,
        type: Optional[str] = None,
        kernel_name: Optional[KernelName] = None,
        kernel_id: Optional[str] = None,
    ) -> dict[str, Any]:
        """Creates a session and returns its model

        If ``kernel_id`` names a kernel the kernel manager already knows, the
        session is attached to it; otherwise a new kernel is started.

        Parameters
        ----------
        path : str
            the path for the session
        name: ModelName(str)
            Usually the model name, like the filename associated with current
            kernel.
        type : str
            the type of the session
        kernel_name : KernelName(str)
            the kernel specification to use when a new kernel is started
        kernel_id : str
            id of an existing kernel to attach to

        Returns
        -------
        model : dict
            a dictionary of the session model
        """
        session_id = self.new_session_id()
        record = KernelSessionRecord(session_id=session_id)
        self._pending_sessions.update(record)
        try:
            if kernel_id is None or kernel_id not in self.kernel_manager:
                kernel_id = await self.start_kernel_for_session(
                    session_id, path, name, type, kernel_name
                )
            record.kernel_id = kernel_id
            self._pending_sessions.update(record)
            result = await self.save_session(
                session_id, path=path, name=name, type=type, kernel_id=kernel_id
            )
        finally:
            # Drop the pending record even if starting the kernel or saving
            # the row failed, so a failed request cannot leave a stale entry.
            self._pending_sessions.remove(record)
        return cast(dict[str, Any], result)

    def get_kernel_env(
        self, path: Optional[str], name: Optional[ModelName] = None
    ) -> dict[str, str]:
        """Return the environment variables that need to be set in the kernel

        The result is a copy of ``os.environ`` plus ``JPY_SESSION_NAME``.

        Parameters
        ----------
        path : str
            the url path for the given session.
        name: ModelName(str), optional
            Here the name is likely to be the name of the associated file
            with the current kernel at startup time.
        """
        if name is not None:
            cwd = self.kernel_manager.cwd_for_path(path)
            path = os.path.join(cwd, name)
        # Narrows Optional[str] for type checkers; a call with neither a path
        # nor a name fails here.
        assert isinstance(path, str)
        return {**os.environ, "JPY_SESSION_NAME": path}

    async def start_kernel_for_session(
        self,
        session_id: str,
        path: Optional[str],
        name: Optional[ModelName],
        type: Optional[str],
        kernel_name: Optional[KernelName],
    ) -> str:
        """Start a new kernel for a given session.

        Parameters
        ----------
        session_id : str
            uuid for the session; this method must be given a session_id
        path : str
            the path for the given session - seem to be a session id sometime.
        name : str
            Usually the model name, like the filename associated with current
            kernel.
        type : str
            the type of the session
        kernel_name : str
            the name of the kernel specification to use.  The default kernel name will be used if not provided.

        Returns
        -------
        kernel_id : str
            the id returned by the kernel manager for the started kernel
        """
        # allow contents manager to specify kernels cwd
        kernel_path = await ensure_async(self.contents_manager.get_kernel_path(path=path))

        kernel_env = self.get_kernel_env(path, name)
        kernel_id = await self.kernel_manager.start_kernel(
            path=kernel_path,
            kernel_name=kernel_name,
            env=kernel_env,
        )
        return cast(str, kernel_id)

    async def save_session(self, session_id, path=None, name=None, type=None, kernel_id=None):
        """Saves the items for the session with the given session_id

        Given a session_id (and any other of the arguments), this method
        creates a row in the sqlite session database that holds the information
        for a session.

        Parameters
        ----------
        session_id : str
            uuid for the session; this method must be given a session_id
        path : str
            the path for the given session
        name : str
            the name of the session
        type : str
            the type of the session
        kernel_id : str
            a uuid for the kernel associated with this session

        Returns
        -------
        model : dict
            a dictionary of the session model
        """
        self.cursor.execute(
            "INSERT INTO session VALUES (?,?,?,?,?)",
            (session_id, path, name, type, kernel_id),
        )
        return await self.get_session(session_id=session_id)

    async def get_session(self, **kwargs):
        """Returns the model for a particular session.

        Takes a keyword argument and searches for the value in the session
        database, then returns the rest of the session's info.

        Parameters
        ----------
        **kwargs : dict
            must be given one of the keywords and values from the session database
            (i.e. session_id, path, name, type, kernel_id)

        Returns
        -------
        model : dict
            returns a dictionary that includes all the information from the
            session described by the kwarg.

        Raises
        ------
        TypeError
            If no keyword is given or a keyword is not a session column.
        tornado.web.HTTPError
            404 if no row matches, or if the matching row's kernel is gone.
        """
        if not kwargs:
            msg = "must specify a column to query"
            raise TypeError(msg)

        conditions = []
        for column in kwargs:
            if column not in self._columns:
                msg = f"No such column: {column}"
                raise TypeError(msg)
            conditions.append(f"{column}=?")

        # Only whitelisted column names reach the SQL text; values are bound
        # as parameters.
        where = " AND ".join(conditions)
        query = f"SELECT * FROM session WHERE {where}"  # noqa: S608

        self.cursor.execute(query, list(kwargs.values()))
        try:
            row = self.cursor.fetchone()
        except KeyError:
            # The kernel is missing, so the session just got deleted.
            row = None

        if row is None:
            criteria = ", ".join(f"{key}={value!r}" for key, value in kwargs.items())
            raise web.HTTPError(404, f"Session not found: {criteria}")

        try:
            model = await self.row_to_model(row)
        except KeyError as e:
            # row_to_model raises KeyError when the row's kernel was culled.
            raise web.HTTPError(404, f"Session not found: {e}") from e
        return model

    async def update_session(self, session_id, **kwargs):
        """Updates the values in the session database.

        Changes the values of the session with the given session_id
        with the values from the keyword arguments.

        Parameters
        ----------
        session_id : str
            a uuid that identifies a session in the sqlite3 database
        **kwargs : str
            the key must correspond to a column title in session database,
            and the value replaces the current value in the session
            with session_id.

        Raises
        ------
        TypeError
            If a keyword is not a session column.
        tornado.web.HTTPError
            404 if the session does not exist.
        """
        # Raises 404 if the session is unknown.
        await self.get_session(session_id=session_id)

        if not kwargs:
            # no changes
            return

        sets = []
        for column in kwargs:
            if column not in self._columns:
                raise TypeError(f"No such column: {column!r}")
            sets.append(f"{column}=?")
        assignments = ", ".join(sets)
        query = f"UPDATE session SET {assignments} WHERE session_id=?"  # noqa: S608
        self.cursor.execute(query, [*kwargs.values(), session_id])

        # Kernel managers that support it get the refreshed environment.
        if hasattr(self.kernel_manager, "update_env"):
            self.cursor.execute(
                "SELECT path, name, kernel_id FROM session WHERE session_id=?", [session_id]
            )
            path, name, kernel_id = self.cursor.fetchone()
            self.kernel_manager.update_env(kernel_id=kernel_id, env=self.get_kernel_env(path, name))

    async def kernel_culled(self, kernel_id: str) -> bool:
        """Checks if the kernel is still considered alive and returns true if its not found."""
        return kernel_id not in self.kernel_manager

    async def row_to_model(self, row, tolerate_culled=False):
        """Takes sqlite database session row and turns it into a dictionary

        If the row's kernel is no longer known, the row is deleted. Then,
        when ``tolerate_culled`` is true, a warning is logged and ``None`` is
        returned; otherwise ``KeyError`` is raised.
        """
        kernel_id = row["kernel_id"]
        session_id = row["session_id"]
        culled: bool = await ensure_async(self.kernel_culled(kernel_id))
        if culled:
            # The kernel was culled or died without deleting the session.
            # We can't use delete_session here because that tries to find
            # and shut down the kernel - so we'll delete the row directly.
            #
            # If caller wishes to tolerate culled kernels, log a warning
            # and return None.  Otherwise, raise KeyError with a similar
            # message.
            self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,))
            msg = (
                f"Kernel '{kernel_id}' appears to have been culled or died unexpectedly, "
                f"invalidating session '{session_id}'. The session has been removed."
            )
            if tolerate_culled:
                self.log.warning(f"{msg}  Continuing...")
                return None
            raise KeyError(msg)

        kernel_model = await ensure_async(self.kernel_manager.kernel_model(kernel_id))
        model = {
            "id": session_id,
            "path": row["path"],
            "name": row["name"],
            "type": row["type"],
            "kernel": kernel_model,
        }
        if row["type"] == "notebook":
            # Provide the deprecated API.
            model["notebook"] = {"path": row["path"], "name": row["name"]}
        return model

    async def list_sessions(self):
        """Returns a list of dictionaries containing all the information from
        the session database

        Rows whose kernel is gone are skipped (and removed by row_to_model).
        """
        c = self.cursor.execute("SELECT * FROM session")
        result = []
        # We need to use fetchall() here, because row_to_model can delete rows,
        # which messes up the cursor if we're iterating over rows.
        for row in c.fetchall():
            try:
                model = await self.row_to_model(row)
            except KeyError:
                # Deliberate: a culled kernel just means this session no
                # longer exists, so it is left out of the listing.
                continue
            result.append(model)
        return result

    async def delete_session(self, session_id):
        """Deletes the row in the session database with given session_id

        The session's kernel is shut down before the row is removed.

        Raises
        ------
        tornado.web.HTTPError
            404 if the session does not exist.
        """
        record = KernelSessionRecord(session_id=session_id)
        self._pending_sessions.update(record)
        try:
            session = await self.get_session(session_id=session_id)
            await ensure_async(self.kernel_manager.shutdown_kernel(session["kernel"]["id"]))
            self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,))
        finally:
            # Drop the pending record even if the lookup or shutdown failed,
            # so a failed request cannot leave a stale entry.
            self._pending_sessions.remove(record)
|