Primer commit del proyecto RSS

This commit is contained in:
jlimolina 2025-05-24 14:37:58 +02:00
commit 27c9515d29
1568 changed files with 252311 additions and 0 deletions

View file

@ -0,0 +1,255 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""MySQL Connector/Python - MySQL driver written in Python."""
__all__ = ["CMySQLConnection", "MySQLConnection", "connect"]
import random
from typing import Any
from ..constants import DEFAULT_CONFIGURATION
from ..errors import Error, InterfaceError, ProgrammingError
from ..pooling import ERROR_NO_CEXT
from .abstracts import MySQLConnectionAbstract
from .connection import MySQLConnection
try:
import dns.exception
import dns.resolver
except ImportError:
HAVE_DNSPYTHON = False
else:
HAVE_DNSPYTHON = True
try:
from .connection_cext import CMySQLConnection
except ImportError:
CMySQLConnection = None
async def connect(*args: Any, **kwargs: Any) -> MySQLConnectionAbstract:
"""Creates or gets a MySQL connection object.
In its simpliest form, `connect()` will open a connection to a
MySQL server and return a `MySQLConnectionAbstract` subclass
object such as `MySQLConnection` or `CMySQLConnection`.
When any connection pooling arguments are given, for example `pool_name`
or `pool_size`, a pool is created or a previously one is used to return
a `PooledMySQLConnection`.
Args:
*args: N/A.
**kwargs: For a complete list of possible arguments, see [1]. If no arguments
are given, it uses the already configured or default values.
Returns:
A `MySQLConnectionAbstract` subclass instance (such as `MySQLConnection` or
a `CMySQLConnection`) instance.
Examples:
A connection with the MySQL server can be established using either the
`mysql.connector.connect()` method or a `MySQLConnectionAbstract` subclass:
```
>>> from mysql.connector.aio import MySQLConnection, HAVE_CEXT
>>>
>>> cnx1 = await mysql.connector.aio.connect(user='joe', database='test')
>>> cnx2 = MySQLConnection(user='joe', database='test')
>>> await cnx2.connect()
>>>
>>> cnx3 = None
>>> if HAVE_CEXT:
>>> from mysql.connector.aio import CMySQLConnection
>>> cnx3 = CMySQLConnection(user='joe', database='test')
```
References:
[1]: https://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html
"""
# DNS SRV
dns_srv = kwargs.pop("dns_srv") if "dns_srv" in kwargs else False
if not isinstance(dns_srv, bool):
raise InterfaceError("The value of 'dns-srv' must be a boolean")
if dns_srv:
if not HAVE_DNSPYTHON:
raise InterfaceError(
"MySQL host configuration requested DNS "
"SRV. This requires the Python dnspython "
"module. Please refer to documentation"
)
if "unix_socket" in kwargs:
raise InterfaceError(
"Using Unix domain sockets with DNS SRV lookup is not allowed"
)
if "port" in kwargs:
raise InterfaceError(
"Specifying a port number with DNS SRV lookup is not allowed"
)
if "failover" in kwargs:
raise InterfaceError(
"Specifying multiple hostnames with DNS SRV look up is not allowed"
)
if "host" not in kwargs:
kwargs["host"] = DEFAULT_CONFIGURATION["host"]
try:
srv_records = dns.resolver.query(kwargs["host"], "SRV")
except dns.exception.DNSException:
raise InterfaceError(
f"Unable to locate any hosts for '{kwargs['host']}'"
) from None
failover = []
for srv in srv_records:
failover.append(
{
"host": srv.target.to_text(omit_final_dot=True),
"port": srv.port,
"priority": srv.priority,
"weight": srv.weight,
}
)
failover.sort(key=lambda x: (x["priority"], -x["weight"]))
kwargs["failover"] = [
{"host": srv["host"], "port": srv["port"]} for srv in failover
]
# Failover
if "failover" in kwargs:
return await _get_failover_connection(**kwargs)
# Use C Extension by default
use_pure = kwargs.get("use_pure", False)
if "use_pure" in kwargs:
del kwargs["use_pure"] # Remove 'use_pure' from kwargs
if not use_pure and CMySQLConnection is None:
raise ImportError(ERROR_NO_CEXT)
if CMySQLConnection and not use_pure:
cnx = CMySQLConnection(*args, **kwargs)
else:
cnx = MySQLConnection(*args, **kwargs)
await cnx.connect()
return cnx
async def _get_failover_connection(**kwargs: Any) -> MySQLConnectionAbstract:
"""Return a MySQL connection and try to failover if needed.
An InterfaceError is raise when no MySQL is available. ValueError is
raised when the failover server configuration contains an illegal
connection argument. Supported arguments are user, password, host, port,
unix_socket and database. ValueError is also raised when the failover
argument was not provided.
Returns MySQLConnection instance.
"""
config = kwargs.copy()
try:
failover = config["failover"]
except KeyError:
raise ValueError("failover argument not provided") from None
del config["failover"]
support_cnx_args = set(
[
"user",
"password",
"host",
"port",
"unix_socket",
"database",
"pool_name",
"pool_size",
"priority",
]
)
# First check if we can add all use the configuration
priority_count = 0
for server in failover:
diff = set(server.keys()) - support_cnx_args
if diff:
arg = "s" if len(diff) > 1 else ""
lst = ", ".join(diff)
raise ValueError(
f"Unsupported connection argument {arg} in failover: {lst}"
)
if hasattr(server, "priority"):
priority_count += 1
server["priority"] = server.get("priority", 100)
if server["priority"] < 0 or server["priority"] > 100:
raise InterfaceError(
"Priority value should be in the range of 0 to 100, "
f"got : {server['priority']}"
)
if not isinstance(server["priority"], int):
raise InterfaceError(
"Priority value should be an integer in the range of 0 to "
f"100, got : {server['priority']}"
)
if 0 < priority_count < len(failover):
raise ProgrammingError(
"You must either assign no priority to any "
"of the routers or give a priority for "
"every router"
)
server_directory = {}
server_priority_list = []
for server in sorted(failover, key=lambda x: x["priority"], reverse=True):
if server["priority"] not in server_directory:
server_directory[server["priority"]] = [server]
server_priority_list.append(server["priority"])
else:
server_directory[server["priority"]].append(server)
for priority in server_priority_list:
failover_list = server_directory[priority]
for _ in range(len(failover_list)):
last = len(failover_list) - 1
index = random.randint(0, last)
server = failover_list.pop(index)
new_config = config.copy()
new_config.update(server)
new_config.pop("priority", None)
try:
return await connect(**new_config)
except Error:
# If we failed to connect, we try the next server
pass
raise InterfaceError("Unable to connect to any of the target hosts")

View file

@ -0,0 +1,112 @@
# Copyright (c) 2009, 2025, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Decorators Hub."""
import functools
import warnings
from typing import TYPE_CHECKING, Any, Callable
from ..constants import RefreshOption
from ..errors import ReadTimeoutError, WriteTimeoutError
if TYPE_CHECKING:
from .abstracts import MySQLConnectionAbstract
def cmd_refresh_verify_options() -> Callable:
"""Decorator verifying which options are relevant and which aren't based on
the server version the client is connecting to."""
def decorator(cmd_refresh: Callable) -> Callable:
@functools.wraps(cmd_refresh)
async def wrapper(
cnx: "MySQLConnectionAbstract", *args: Any, **kwargs: Any
) -> Callable:
options: int = args[0]
if (options & RefreshOption.GRANT) and cnx.server_version >= (
9,
2,
0,
):
warnings.warn(
"As of MySQL Server 9.2.0, refreshing grant tables is not needed "
"if you use statements GRANT, REVOKE, CREATE, DROP, or ALTER. "
"You should expect this option to be unsupported in a future "
"version of MySQL Connector/Python when MySQL Server removes it.",
category=DeprecationWarning,
stacklevel=1,
)
return await cmd_refresh(cnx, options, **kwargs)
return wrapper
return decorator
def deprecated(reason: str) -> Callable:
"""Use it to decorate deprecated methods."""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Callable:
warnings.warn(
f"Call to deprecated function {func.__name__}. Reason: {reason}",
category=DeprecationWarning,
stacklevel=2,
)
return await func(*args, **kwargs)
return wrapper
return decorator
def handle_read_write_timeout() -> Callable:
"""
Decorator to close the current connection if a read or a write timeout
is raised by the method passed via the func parameter.
"""
def decorator(cnx_method: Callable) -> Callable:
@functools.wraps(cnx_method)
async def handle_cnx_method(
cnx: "MySQLConnectionAbstract", *args: Any, **kwargs: Any
) -> Any:
try:
return await cnx_method(cnx, *args, **kwargs)
except Exception as err:
if isinstance(err, (ReadTimeoutError, WriteTimeoutError)):
await cnx.close()
raise err
return handle_cnx_method
return decorator

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,335 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Implementing support for MySQL Authentication Plugins."""
from __future__ import annotations
__all__ = ["MySQLAuthenticator"]
from typing import TYPE_CHECKING, Any, Dict, Optional
from ..errors import InterfaceError, NotSupportedError, get_exception
from ..protocol import (
AUTH_SWITCH_STATUS,
DEFAULT_CHARSET_ID,
DEFAULT_MAX_ALLOWED_PACKET,
ERR_STATUS,
EXCHANGE_FURTHER_STATUS,
MFA_STATUS,
OK_STATUS,
)
from ..types import HandShakeType
from .logger import logger
from .plugins import MySQLAuthPlugin, get_auth_plugin
from .protocol import MySQLProtocol
if TYPE_CHECKING:
from .network import MySQLSocket
class MySQLAuthenticator:
"""Implements the authentication phase."""
def __init__(self) -> None:
"""Constructor."""
self._username: str = ""
self._passwords: Dict[int, str] = {}
self._plugin_config: Dict[str, Any] = {}
self._ssl_enabled: bool = False
self._auth_strategy: Optional[MySQLAuthPlugin] = None
self._auth_plugin_class: Optional[str] = None
@property
def ssl_enabled(self) -> bool:
"""Signals whether or not SSL is enabled."""
return self._ssl_enabled
@property
def plugin_config(self) -> Dict[str, Any]:
"""Custom arguments that are being provided to the authentication plugin.
The parameters defined here will override the ones defined in the
auth plugin itself.
The plugin config is a read-only property - the plugin configuration
provided when invoking `authenticate()` is recorded and can be queried
by accessing this property.
Returns:
dict: The latest plugin configuration provided when invoking
`authenticate()`.
"""
return self._plugin_config
def update_plugin_config(self, config: Dict[str, Any]) -> None:
"""Update the 'plugin_config' instance variable"""
self._plugin_config.update(config)
def _switch_auth_strategy(
self,
new_strategy_name: str,
strategy_class: Optional[str] = None,
username: Optional[str] = None,
password_factor: int = 1,
) -> None:
"""Switch the authorization plugin.
Args:
new_strategy_name: New authorization plugin name to switch to.
strategy_class: New authorization plugin class to switch to
(has higher precedence than the authorization plugin name).
username: Username to be used - if not defined, the username
provided when `authentication()` was invoked is used.
password_factor: Up to three levels of authentication (MFA) are allowed,
hence you can choose the password corresponding to the 1st,
2nd, or 3rd factor - 1st is the default.
"""
if username is None:
username = self._username
if strategy_class is None:
strategy_class = self._auth_plugin_class
logger.debug("Switching to strategy %s", new_strategy_name)
self._auth_strategy = get_auth_plugin(
plugin_name=new_strategy_name, auth_plugin_class=strategy_class
)(
username,
self._passwords.get(password_factor, ""),
ssl_enabled=self.ssl_enabled,
)
async def _mfa_n_factor(
self,
sock: MySQLSocket,
pkt: bytes,
) -> Optional[bytes]:
"""Handle MFA (Multi-Factor Authentication) response.
Up to three levels of authentication (MFA) are allowed.
Args:
sock: Pointer to the socket connection.
pkt: MFA response.
Returns:
ok_packet: If last server's response is an OK packet.
None: If last server's response isn't an OK packet and no ERROR was raised.
Raises:
InterfaceError: If got an invalid N factor.
errors.ErrorTypes: If got an ERROR response.
"""
n_factor = 2
while pkt[4] == MFA_STATUS:
if n_factor not in self._passwords:
raise InterfaceError(
"Failed Multi Factor Authentication (invalid N factor)"
)
new_strategy_name, auth_data = MySQLProtocol.parse_auth_next_factor(pkt)
self._switch_auth_strategy(new_strategy_name, password_factor=n_factor)
logger.debug("MFA %i factor %s", n_factor, self._auth_strategy.name)
pkt = await self._auth_strategy.auth_switch_response(
sock, auth_data, **self._plugin_config
)
if pkt[4] == EXCHANGE_FURTHER_STATUS:
auth_data = MySQLProtocol.parse_auth_more_data(pkt)
pkt = await self._auth_strategy.auth_more_response(
sock, auth_data, **self._plugin_config
)
if pkt[4] == OK_STATUS:
logger.debug("MFA completed succesfully")
return pkt
if pkt[4] == ERR_STATUS:
raise get_exception(pkt)
n_factor += 1
logger.warning("MFA terminated with a no ok packet")
return None
async def _handle_server_response(
self,
sock: MySQLSocket,
pkt: bytes,
) -> Optional[bytes]:
"""Handle server's response.
Args:
sock: Pointer to the socket connection.
pkt: Server's response after completing the `HandShakeResponse`.
Returns:
ok_packet: If last server's response is an OK packet.
None: If last server's response isn't an OK packet and no ERROR was raised.
Raises:
errors.ErrorTypes: If got an ERROR response.
NotSupportedError: If got Authentication with old (insecure) passwords.
"""
if pkt[4] == AUTH_SWITCH_STATUS and len(pkt) == 5:
raise NotSupportedError(
"Authentication with old (insecure) passwords "
"is not supported. For more information, lookup "
"Password Hashing in the latest MySQL manual"
)
if pkt[4] == AUTH_SWITCH_STATUS:
logger.debug("Server's response is an auth switch request")
new_strategy_name, auth_data = MySQLProtocol.parse_auth_switch_request(pkt)
self._switch_auth_strategy(new_strategy_name)
pkt = await self._auth_strategy.auth_switch_response(
sock, auth_data, **self._plugin_config
)
if pkt[4] == EXCHANGE_FURTHER_STATUS:
logger.debug("Exchanging further packets")
auth_data = MySQLProtocol.parse_auth_more_data(pkt)
pkt = await self._auth_strategy.auth_more_response(
sock, auth_data, **self._plugin_config
)
if pkt[4] == OK_STATUS:
logger.debug("%s completed succesfully", self._auth_strategy.name)
return pkt
if pkt[4] == MFA_STATUS:
logger.debug("Starting multi-factor authentication")
logger.debug("MFA 1 factor %s", self._auth_strategy.name)
return await self._mfa_n_factor(sock, pkt)
if pkt[4] == ERR_STATUS:
raise get_exception(pkt)
return None
async def authenticate(
self,
sock: MySQLSocket,
handshake: HandShakeType,
username: str = "",
password1: str = "",
password2: str = "",
password3: str = "",
database: Optional[str] = None,
charset: int = DEFAULT_CHARSET_ID,
client_flags: int = 0,
ssl_enabled: bool = False,
max_allowed_packet: int = DEFAULT_MAX_ALLOWED_PACKET,
auth_plugin: Optional[str] = None,
auth_plugin_class: Optional[str] = None,
conn_attrs: Optional[Dict[str, str]] = None,
is_change_user_request: bool = False,
read_timeout: Optional[int] = None,
write_timeout: Optional[int] = None,
) -> bytes:
"""Perform the authentication phase.
During re-authentication you must set `is_change_user_request` to True.
Args:
sock: Pointer to the socket connection.
handshake: Initial handshake.
username: Account's username.
password1: Account's password factor 1.
password2: Account's password factor 2.
password3: Account's password factor 3.
database: Initial database name for the connection.
charset: Client charset (see [1]), only the lower 8-bits.
client_flags: Integer representing client capabilities flags.
ssl_enabled: Boolean indicating whether SSL is enabled,
max_allowed_packet: Maximum packet size.
auth_plugin: Authorization plugin name.
auth_plugin_class: Authorization plugin class (has higher precedence
than the authorization plugin name).
conn_attrs: Connection attributes.
is_change_user_request: Whether is a `change user request` operation or not.
read_timeout: Timeout in seconds upto which the connector should wait for
the server to reply back before raising an ReadTimeoutError.
write_timeout: Timeout in seconds upto which the connector should spend to
send data to the server before raising an WriteTimeoutError.
Returns:
ok_packet: OK packet.
Raises:
InterfaceError: If OK packet is NULL.
ReadTimeoutError: If the time taken for the server to reply back exceeds
'read_timeout' (if set).
WriteTimeoutError: If the time taken to send data packets to the server
exceeds 'write_timeout' (if set).
References:
[1]: https://dev.mysql.com/doc/dev/mysql-server/latest/\
page_protocol_basic_character_set.html#a_protocol_character_set
"""
# update credentials, plugin config and plugin class
self._username = username
self._passwords = {1: password1, 2: password2, 3: password3}
self._ssl_enabled = ssl_enabled
self._auth_plugin_class = auth_plugin_class
# client's handshake response
response_payload, self._auth_strategy = MySQLProtocol.make_auth(
handshake=handshake,
username=username,
password=password1,
database=database,
charset=charset,
client_flags=client_flags,
max_allowed_packet=max_allowed_packet,
auth_plugin=auth_plugin,
auth_plugin_class=auth_plugin_class,
conn_attrs=conn_attrs,
is_change_user_request=is_change_user_request,
ssl_enabled=self.ssl_enabled,
plugin_config=self.plugin_config,
)
# client sends transaction response
send_args = (
(0, 0, write_timeout)
if is_change_user_request
else (None, None, write_timeout)
)
await sock.write(response_payload, *send_args)
# server replies back
pkt = bytes(await sock.read(read_timeout))
ok_pkt = await self._handle_server_response(sock, pkt)
if ok_pkt is None:
raise InterfaceError("Got a NULL ok_pkt") from None
return ok_pkt

View file

@ -0,0 +1,686 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""This module contains the MySQL Server Character Sets."""
__all__ = ["Charset", "charsets"]
from collections import defaultdict
from dataclasses import dataclass
from typing import DefaultDict, Dict, Optional, Sequence, Tuple
from ..errors import ProgrammingError
@dataclass
class Charset:
"""Dataclass representing a character set."""
charset_id: int
name: str
collation: str
is_default: bool
class Charsets:
"""MySQL supported character sets and collations class.
This class holds the list of character sets with their collations supported by
MySQL, making available methods to get character sets by name, collation, or ID.
It uses a sparse matrix or tree-like representation using a dict in a dict to hold
the character set name and collations combinations.
The list is hardcoded, so we avoid a database query when getting the name of the
used character set or collation.
The call of ``charsets.set_mysql_major_version()`` should be done before using any
of the retrieval methods.
Usage:
>>> from mysql.connector.aio.charsets import charsets
>>> charsets.set_mysql_major_version(8)
>>> charsets.get_by_name("utf-8")
Charset(charset_id=255,
name='utf8mb4',
collation='utf8mb4_0900_ai_ci',
is_default=True)
"""
def __init__(self) -> None:
self._charset_id_store: Dict[int, Charset] = {}
self._collation_store: Dict[str, Charset] = {}
self._name_store: DefaultDict[str, Dict[str, Charset]] = defaultdict(dict)
self._mysql_major_version: Optional[int] = None
def set_mysql_major_version(self, version: int) -> None:
"""Set the MySQL major version.
Sets what tuple should be used based on the MySQL major version to store the
list of character sets and collations.
Args:
version: The MySQL major version (i.e. 8 or 5)
"""
self._mysql_major_version = version
self._charset_id_store.clear()
self._collation_store.clear()
self._name_store.clear()
charsets_tuple: Sequence[Tuple[int, str, str, bool]] = None
if version >= 8:
charsets_tuple = MYSQL_8_CHARSETS
elif version == 5:
charsets_tuple = MYSQL_5_CHARSETS
else:
raise ProgrammingError("Invalid MySQL major version")
for charset_id, name, collation, is_default in charsets_tuple:
charset = Charset(charset_id, name, collation, is_default)
self._charset_id_store[charset_id] = charset
self._collation_store[collation] = charset
self._name_store[name][collation] = charset
def get_by_id(self, charset_id: int) -> Charset:
"""Get character set by ID.
Args:
charset_id: The charset ID.
Returns:
Charset: The Charset dataclass instance.
"""
try:
return self._charset_id_store[charset_id]
except KeyError as err:
raise ProgrammingError(f"Character set ID {charset_id} unknown") from err
def get_by_collation(self, collation: str) -> Charset:
"""Get character set by collation.
Args:
collation: The collation name.
Returns:
Charset: The Charset dataclass instance.
"""
try:
return self._collation_store[collation]
except KeyError as err:
raise ProgrammingError(f"Collation {collation} unknown") from err
def get_by_name(self, name: str) -> Charset:
"""Get character set by name.
Args:
name: The charset name.
Returns:
Charset: The Charset dataclass instance.
"""
try:
if name in ("utf8", "utf-8") and self._mysql_major_version == 8:
name = "utf8mb4"
for charset in self._name_store[name].values():
if charset.is_default:
return charset
except KeyError as err:
raise ProgrammingError(f"Character set name {name} unknown") from err
raise ProgrammingError(f"No default was found for character set '{name}'")
def get_by_name_and_collation(self, name: str, collation: str) -> Charset:
"""Get character set by name and collation.
Args:
name: The charset name.
collation: The collation name.
Returns:
Charset: The Charset dataclass instance.
"""
try:
return self._name_store[name][collation]
except KeyError as err:
raise ProgrammingError(
f"Character set name '{name}' with collation '{collation}' not found"
) from err
MYSQL_8_CHARSETS = (
(1, "big5", "big5_chinese_ci", True),
(2, "latin2", "latin2_czech_cs", False),
(3, "dec8", "dec8_swedish_ci", True),
(4, "cp850", "cp850_general_ci", True),
(5, "latin1", "latin1_german1_ci", False),
(6, "hp8", "hp8_english_ci", True),
(7, "koi8r", "koi8r_general_ci", True),
(8, "latin1", "latin1_swedish_ci", True),
(9, "latin2", "latin2_general_ci", True),
(10, "swe7", "swe7_swedish_ci", True),
(11, "ascii", "ascii_general_ci", True),
(12, "ujis", "ujis_japanese_ci", True),
(13, "sjis", "sjis_japanese_ci", True),
(14, "cp1251", "cp1251_bulgarian_ci", False),
(15, "latin1", "latin1_danish_ci", False),
(16, "hebrew", "hebrew_general_ci", True),
(18, "tis620", "tis620_thai_ci", True),
(19, "euckr", "euckr_korean_ci", True),
(20, "latin7", "latin7_estonian_cs", False),
(21, "latin2", "latin2_hungarian_ci", False),
(22, "koi8u", "koi8u_general_ci", True),
(23, "cp1251", "cp1251_ukrainian_ci", False),
(24, "gb2312", "gb2312_chinese_ci", True),
(25, "greek", "greek_general_ci", True),
(26, "cp1250", "cp1250_general_ci", True),
(27, "latin2", "latin2_croatian_ci", False),
(28, "gbk", "gbk_chinese_ci", True),
(29, "cp1257", "cp1257_lithuanian_ci", False),
(30, "latin5", "latin5_turkish_ci", True),
(31, "latin1", "latin1_german2_ci", False),
(32, "armscii8", "armscii8_general_ci", True),
(33, "utf8mb3", "utf8mb3_general_ci", True),
(34, "cp1250", "cp1250_czech_cs", False),
(35, "ucs2", "ucs2_general_ci", True),
(36, "cp866", "cp866_general_ci", True),
(37, "keybcs2", "keybcs2_general_ci", True),
(38, "macce", "macce_general_ci", True),
(39, "macroman", "macroman_general_ci", True),
(40, "cp852", "cp852_general_ci", True),
(41, "latin7", "latin7_general_ci", True),
(42, "latin7", "latin7_general_cs", False),
(43, "macce", "macce_bin", False),
(44, "cp1250", "cp1250_croatian_ci", False),
(45, "utf8mb4", "utf8mb4_general_ci", False),
(46, "utf8mb4", "utf8mb4_bin", False),
(47, "latin1", "latin1_bin", False),
(48, "latin1", "latin1_general_ci", False),
(49, "latin1", "latin1_general_cs", False),
(50, "cp1251", "cp1251_bin", False),
(51, "cp1251", "cp1251_general_ci", True),
(52, "cp1251", "cp1251_general_cs", False),
(53, "macroman", "macroman_bin", False),
(54, "utf16", "utf16_general_ci", True),
(55, "utf16", "utf16_bin", False),
(56, "utf16le", "utf16le_general_ci", True),
(57, "cp1256", "cp1256_general_ci", True),
(58, "cp1257", "cp1257_bin", False),
(59, "cp1257", "cp1257_general_ci", True),
(60, "utf32", "utf32_general_ci", True),
(61, "utf32", "utf32_bin", False),
(62, "utf16le", "utf16le_bin", False),
(63, "binary", "binary", True),
(64, "armscii8", "armscii8_bin", False),
(65, "ascii", "ascii_bin", False),
(66, "cp1250", "cp1250_bin", False),
(67, "cp1256", "cp1256_bin", False),
(68, "cp866", "cp866_bin", False),
(69, "dec8", "dec8_bin", False),
(70, "greek", "greek_bin", False),
(71, "hebrew", "hebrew_bin", False),
(72, "hp8", "hp8_bin", False),
(73, "keybcs2", "keybcs2_bin", False),
(74, "koi8r", "koi8r_bin", False),
(75, "koi8u", "koi8u_bin", False),
(76, "utf8mb3", "utf8mb3_tolower_ci", False),
(77, "latin2", "latin2_bin", False),
(78, "latin5", "latin5_bin", False),
(79, "latin7", "latin7_bin", False),
(80, "cp850", "cp850_bin", False),
(81, "cp852", "cp852_bin", False),
(82, "swe7", "swe7_bin", False),
(83, "utf8mb3", "utf8mb3_bin", False),
(84, "big5", "big5_bin", False),
(85, "euckr", "euckr_bin", False),
(86, "gb2312", "gb2312_bin", False),
(87, "gbk", "gbk_bin", False),
(88, "sjis", "sjis_bin", False),
(89, "tis620", "tis620_bin", False),
(90, "ucs2", "ucs2_bin", False),
(91, "ujis", "ujis_bin", False),
(92, "geostd8", "geostd8_general_ci", True),
(93, "geostd8", "geostd8_bin", False),
(94, "latin1", "latin1_spanish_ci", False),
(95, "cp932", "cp932_japanese_ci", True),
(96, "cp932", "cp932_bin", False),
(97, "eucjpms", "eucjpms_japanese_ci", True),
(98, "eucjpms", "eucjpms_bin", False),
(99, "cp1250", "cp1250_polish_ci", False),
(101, "utf16", "utf16_unicode_ci", False),
(102, "utf16", "utf16_icelandic_ci", False),
(103, "utf16", "utf16_latvian_ci", False),
(104, "utf16", "utf16_romanian_ci", False),
(105, "utf16", "utf16_slovenian_ci", False),
(106, "utf16", "utf16_polish_ci", False),
(107, "utf16", "utf16_estonian_ci", False),
(108, "utf16", "utf16_spanish_ci", False),
(109, "utf16", "utf16_swedish_ci", False),
(110, "utf16", "utf16_turkish_ci", False),
(111, "utf16", "utf16_czech_ci", False),
(112, "utf16", "utf16_danish_ci", False),
(113, "utf16", "utf16_lithuanian_ci", False),
(114, "utf16", "utf16_slovak_ci", False),
(115, "utf16", "utf16_spanish2_ci", False),
(116, "utf16", "utf16_roman_ci", False),
(117, "utf16", "utf16_persian_ci", False),
(118, "utf16", "utf16_esperanto_ci", False),
(119, "utf16", "utf16_hungarian_ci", False),
(120, "utf16", "utf16_sinhala_ci", False),
(121, "utf16", "utf16_german2_ci", False),
(122, "utf16", "utf16_croatian_ci", False),
(123, "utf16", "utf16_unicode_520_ci", False),
(124, "utf16", "utf16_vietnamese_ci", False),
(128, "ucs2", "ucs2_unicode_ci", False),
(129, "ucs2", "ucs2_icelandic_ci", False),
(130, "ucs2", "ucs2_latvian_ci", False),
(131, "ucs2", "ucs2_romanian_ci", False),
(132, "ucs2", "ucs2_slovenian_ci", False),
(133, "ucs2", "ucs2_polish_ci", False),
(134, "ucs2", "ucs2_estonian_ci", False),
(135, "ucs2", "ucs2_spanish_ci", False),
(136, "ucs2", "ucs2_swedish_ci", False),
(137, "ucs2", "ucs2_turkish_ci", False),
(138, "ucs2", "ucs2_czech_ci", False),
(139, "ucs2", "ucs2_danish_ci", False),
(140, "ucs2", "ucs2_lithuanian_ci", False),
(141, "ucs2", "ucs2_slovak_ci", False),
(142, "ucs2", "ucs2_spanish2_ci", False),
(143, "ucs2", "ucs2_roman_ci", False),
(144, "ucs2", "ucs2_persian_ci", False),
(145, "ucs2", "ucs2_esperanto_ci", False),
(146, "ucs2", "ucs2_hungarian_ci", False),
(147, "ucs2", "ucs2_sinhala_ci", False),
(148, "ucs2", "ucs2_german2_ci", False),
(149, "ucs2", "ucs2_croatian_ci", False),
(150, "ucs2", "ucs2_unicode_520_ci", False),
(151, "ucs2", "ucs2_vietnamese_ci", False),
(159, "ucs2", "ucs2_general_mysql500_ci", False),
(160, "utf32", "utf32_unicode_ci", False),
(161, "utf32", "utf32_icelandic_ci", False),
(162, "utf32", "utf32_latvian_ci", False),
(163, "utf32", "utf32_romanian_ci", False),
(164, "utf32", "utf32_slovenian_ci", False),
(165, "utf32", "utf32_polish_ci", False),
(166, "utf32", "utf32_estonian_ci", False),
(167, "utf32", "utf32_spanish_ci", False),
(168, "utf32", "utf32_swedish_ci", False),
(169, "utf32", "utf32_turkish_ci", False),
(170, "utf32", "utf32_czech_ci", False),
(171, "utf32", "utf32_danish_ci", False),
(172, "utf32", "utf32_lithuanian_ci", False),
(173, "utf32", "utf32_slovak_ci", False),
(174, "utf32", "utf32_spanish2_ci", False),
(175, "utf32", "utf32_roman_ci", False),
(176, "utf32", "utf32_persian_ci", False),
(177, "utf32", "utf32_esperanto_ci", False),
(178, "utf32", "utf32_hungarian_ci", False),
(179, "utf32", "utf32_sinhala_ci", False),
(180, "utf32", "utf32_german2_ci", False),
(181, "utf32", "utf32_croatian_ci", False),
(182, "utf32", "utf32_unicode_520_ci", False),
(183, "utf32", "utf32_vietnamese_ci", False),
(192, "utf8mb3", "utf8mb3_unicode_ci", False),
(193, "utf8mb3", "utf8mb3_icelandic_ci", False),
(194, "utf8mb3", "utf8mb3_latvian_ci", False),
(195, "utf8mb3", "utf8mb3_romanian_ci", False),
(196, "utf8mb3", "utf8mb3_slovenian_ci", False),
(197, "utf8mb3", "utf8mb3_polish_ci", False),
(198, "utf8mb3", "utf8mb3_estonian_ci", False),
(199, "utf8mb3", "utf8mb3_spanish_ci", False),
(200, "utf8mb3", "utf8mb3_swedish_ci", False),
(201, "utf8mb3", "utf8mb3_turkish_ci", False),
(202, "utf8mb3", "utf8mb3_czech_ci", False),
(203, "utf8mb3", "utf8mb3_danish_ci", False),
(204, "utf8mb3", "utf8mb3_lithuanian_ci", False),
(205, "utf8mb3", "utf8mb3_slovak_ci", False),
(206, "utf8mb3", "utf8mb3_spanish2_ci", False),
(207, "utf8mb3", "utf8mb3_roman_ci", False),
(208, "utf8mb3", "utf8mb3_persian_ci", False),
(209, "utf8mb3", "utf8mb3_esperanto_ci", False),
(210, "utf8mb3", "utf8mb3_hungarian_ci", False),
(211, "utf8mb3", "utf8mb3_sinhala_ci", False),
(212, "utf8mb3", "utf8mb3_german2_ci", False),
(213, "utf8mb3", "utf8mb3_croatian_ci", False),
(214, "utf8mb3", "utf8mb3_unicode_520_ci", False),
(215, "utf8mb3", "utf8mb3_vietnamese_ci", False),
(223, "utf8mb3", "utf8mb3_general_mysql500_ci", False),
(224, "utf8mb4", "utf8mb4_unicode_ci", False),
(225, "utf8mb4", "utf8mb4_icelandic_ci", False),
(226, "utf8mb4", "utf8mb4_latvian_ci", False),
(227, "utf8mb4", "utf8mb4_romanian_ci", False),
(228, "utf8mb4", "utf8mb4_slovenian_ci", False),
(229, "utf8mb4", "utf8mb4_polish_ci", False),
(230, "utf8mb4", "utf8mb4_estonian_ci", False),
(231, "utf8mb4", "utf8mb4_spanish_ci", False),
(232, "utf8mb4", "utf8mb4_swedish_ci", False),
(233, "utf8mb4", "utf8mb4_turkish_ci", False),
(234, "utf8mb4", "utf8mb4_czech_ci", False),
(235, "utf8mb4", "utf8mb4_danish_ci", False),
(236, "utf8mb4", "utf8mb4_lithuanian_ci", False),
(237, "utf8mb4", "utf8mb4_slovak_ci", False),
(238, "utf8mb4", "utf8mb4_spanish2_ci", False),
(239, "utf8mb4", "utf8mb4_roman_ci", False),
(240, "utf8mb4", "utf8mb4_persian_ci", False),
(241, "utf8mb4", "utf8mb4_esperanto_ci", False),
(242, "utf8mb4", "utf8mb4_hungarian_ci", False),
(243, "utf8mb4", "utf8mb4_sinhala_ci", False),
(244, "utf8mb4", "utf8mb4_german2_ci", False),
(245, "utf8mb4", "utf8mb4_croatian_ci", False),
(246, "utf8mb4", "utf8mb4_unicode_520_ci", False),
(247, "utf8mb4", "utf8mb4_vietnamese_ci", False),
(248, "gb18030", "gb18030_chinese_ci", True),
(249, "gb18030", "gb18030_bin", False),
(250, "gb18030", "gb18030_unicode_520_ci", False),
(255, "utf8mb4", "utf8mb4_0900_ai_ci", True),
(256, "utf8mb4", "utf8mb4_de_pb_0900_ai_ci", False),
(257, "utf8mb4", "utf8mb4_is_0900_ai_ci", False),
(258, "utf8mb4", "utf8mb4_lv_0900_ai_ci", False),
(259, "utf8mb4", "utf8mb4_ro_0900_ai_ci", False),
(260, "utf8mb4", "utf8mb4_sl_0900_ai_ci", False),
(261, "utf8mb4", "utf8mb4_pl_0900_ai_ci", False),
(262, "utf8mb4", "utf8mb4_et_0900_ai_ci", False),
(263, "utf8mb4", "utf8mb4_es_0900_ai_ci", False),
(264, "utf8mb4", "utf8mb4_sv_0900_ai_ci", False),
(265, "utf8mb4", "utf8mb4_tr_0900_ai_ci", False),
(266, "utf8mb4", "utf8mb4_cs_0900_ai_ci", False),
(267, "utf8mb4", "utf8mb4_da_0900_ai_ci", False),
(268, "utf8mb4", "utf8mb4_lt_0900_ai_ci", False),
(269, "utf8mb4", "utf8mb4_sk_0900_ai_ci", False),
(270, "utf8mb4", "utf8mb4_es_trad_0900_ai_ci", False),
(271, "utf8mb4", "utf8mb4_la_0900_ai_ci", False),
(273, "utf8mb4", "utf8mb4_eo_0900_ai_ci", False),
(274, "utf8mb4", "utf8mb4_hu_0900_ai_ci", False),
(275, "utf8mb4", "utf8mb4_hr_0900_ai_ci", False),
(277, "utf8mb4", "utf8mb4_vi_0900_ai_ci", False),
(278, "utf8mb4", "utf8mb4_0900_as_cs", False),
(279, "utf8mb4", "utf8mb4_de_pb_0900_as_cs", False),
(280, "utf8mb4", "utf8mb4_is_0900_as_cs", False),
(281, "utf8mb4", "utf8mb4_lv_0900_as_cs", False),
(282, "utf8mb4", "utf8mb4_ro_0900_as_cs", False),
(283, "utf8mb4", "utf8mb4_sl_0900_as_cs", False),
(284, "utf8mb4", "utf8mb4_pl_0900_as_cs", False),
(285, "utf8mb4", "utf8mb4_et_0900_as_cs", False),
(286, "utf8mb4", "utf8mb4_es_0900_as_cs", False),
(287, "utf8mb4", "utf8mb4_sv_0900_as_cs", False),
(288, "utf8mb4", "utf8mb4_tr_0900_as_cs", False),
(289, "utf8mb4", "utf8mb4_cs_0900_as_cs", False),
(290, "utf8mb4", "utf8mb4_da_0900_as_cs", False),
(291, "utf8mb4", "utf8mb4_lt_0900_as_cs", False),
(292, "utf8mb4", "utf8mb4_sk_0900_as_cs", False),
(293, "utf8mb4", "utf8mb4_es_trad_0900_as_cs", False),
(294, "utf8mb4", "utf8mb4_la_0900_as_cs", False),
(296, "utf8mb4", "utf8mb4_eo_0900_as_cs", False),
(297, "utf8mb4", "utf8mb4_hu_0900_as_cs", False),
(298, "utf8mb4", "utf8mb4_hr_0900_as_cs", False),
(300, "utf8mb4", "utf8mb4_vi_0900_as_cs", False),
(303, "utf8mb4", "utf8mb4_ja_0900_as_cs", False),
(304, "utf8mb4", "utf8mb4_ja_0900_as_cs_ks", False),
(305, "utf8mb4", "utf8mb4_0900_as_ci", False),
(306, "utf8mb4", "utf8mb4_ru_0900_ai_ci", False),
(307, "utf8mb4", "utf8mb4_ru_0900_as_cs", False),
(308, "utf8mb4", "utf8mb4_zh_0900_as_cs", False),
(309, "utf8mb4", "utf8mb4_0900_bin", False),
(310, "utf8mb4", "utf8mb4_nb_0900_ai_ci", False),
(311, "utf8mb4", "utf8mb4_nb_0900_as_cs", False),
(312, "utf8mb4", "utf8mb4_nn_0900_ai_ci", False),
(313, "utf8mb4", "utf8mb4_nn_0900_as_cs", False),
(314, "utf8mb4", "utf8mb4_sr_latn_0900_ai_ci", False),
(315, "utf8mb4", "utf8mb4_sr_latn_0900_as_cs", False),
(316, "utf8mb4", "utf8mb4_bs_0900_ai_ci", False),
(317, "utf8mb4", "utf8mb4_bs_0900_as_cs", False),
(318, "utf8mb4", "utf8mb4_bg_0900_ai_ci", False),
(319, "utf8mb4", "utf8mb4_bg_0900_as_cs", False),
(320, "utf8mb4", "utf8mb4_gl_0900_ai_ci", False),
(321, "utf8mb4", "utf8mb4_gl_0900_as_cs", False),
(322, "utf8mb4", "utf8mb4_mn_cyrl_0900_ai_ci", False),
(323, "utf8mb4", "utf8mb4_mn_cyrl_0900_as_cs", False),
)
MYSQL_5_CHARSETS = (
(1, "big5", "big5_chinese_ci", True),
(2, "latin2", "latin2_czech_cs", False),
(3, "dec8", "dec8_swedish_ci", True),
(4, "cp850", "cp850_general_ci", True),
(5, "latin1", "latin1_german1_ci", False),
(6, "hp8", "hp8_english_ci", True),
(7, "koi8r", "koi8r_general_ci", True),
(8, "latin1", "latin1_swedish_ci", True),
(9, "latin2", "latin2_general_ci", True),
(10, "swe7", "swe7_swedish_ci", True),
(11, "ascii", "ascii_general_ci", True),
(12, "ujis", "ujis_japanese_ci", True),
(13, "sjis", "sjis_japanese_ci", True),
(14, "cp1251", "cp1251_bulgarian_ci", False),
(15, "latin1", "latin1_danish_ci", False),
(16, "hebrew", "hebrew_general_ci", True),
(18, "tis620", "tis620_thai_ci", True),
(19, "euckr", "euckr_korean_ci", True),
(20, "latin7", "latin7_estonian_cs", False),
(21, "latin2", "latin2_hungarian_ci", False),
(22, "koi8u", "koi8u_general_ci", True),
(23, "cp1251", "cp1251_ukrainian_ci", False),
(24, "gb2312", "gb2312_chinese_ci", True),
(25, "greek", "greek_general_ci", True),
(26, "cp1250", "cp1250_general_ci", True),
(27, "latin2", "latin2_croatian_ci", False),
(28, "gbk", "gbk_chinese_ci", True),
(29, "cp1257", "cp1257_lithuanian_ci", False),
(30, "latin5", "latin5_turkish_ci", True),
(31, "latin1", "latin1_german2_ci", False),
(32, "armscii8", "armscii8_general_ci", True),
(33, "utf8", "utf8_general_ci", True),
(34, "cp1250", "cp1250_czech_cs", False),
(35, "ucs2", "ucs2_general_ci", True),
(36, "cp866", "cp866_general_ci", True),
(37, "keybcs2", "keybcs2_general_ci", True),
(38, "macce", "macce_general_ci", True),
(39, "macroman", "macroman_general_ci", True),
(40, "cp852", "cp852_general_ci", True),
(41, "latin7", "latin7_general_ci", True),
(42, "latin7", "latin7_general_cs", False),
(43, "macce", "macce_bin", False),
(44, "cp1250", "cp1250_croatian_ci", False),
(45, "utf8mb4", "utf8mb4_general_ci", True),
(46, "utf8mb4", "utf8mb4_bin", False),
(47, "latin1", "latin1_bin", False),
(48, "latin1", "latin1_general_ci", False),
(49, "latin1", "latin1_general_cs", False),
(50, "cp1251", "cp1251_bin", False),
(51, "cp1251", "cp1251_general_ci", True),
(52, "cp1251", "cp1251_general_cs", False),
(53, "macroman", "macroman_bin", False),
(54, "utf16", "utf16_general_ci", True),
(55, "utf16", "utf16_bin", False),
(56, "utf16le", "utf16le_general_ci", True),
(57, "cp1256", "cp1256_general_ci", True),
(58, "cp1257", "cp1257_bin", False),
(59, "cp1257", "cp1257_general_ci", True),
(60, "utf32", "utf32_general_ci", True),
(61, "utf32", "utf32_bin", False),
(62, "utf16le", "utf16le_bin", False),
(63, "binary", "binary", True),
(64, "armscii8", "armscii8_bin", False),
(65, "ascii", "ascii_bin", False),
(66, "cp1250", "cp1250_bin", False),
(67, "cp1256", "cp1256_bin", False),
(68, "cp866", "cp866_bin", False),
(69, "dec8", "dec8_bin", False),
(70, "greek", "greek_bin", False),
(71, "hebrew", "hebrew_bin", False),
(72, "hp8", "hp8_bin", False),
(73, "keybcs2", "keybcs2_bin", False),
(74, "koi8r", "koi8r_bin", False),
(75, "koi8u", "koi8u_bin", False),
(77, "latin2", "latin2_bin", False),
(78, "latin5", "latin5_bin", False),
(79, "latin7", "latin7_bin", False),
(80, "cp850", "cp850_bin", False),
(81, "cp852", "cp852_bin", False),
(82, "swe7", "swe7_bin", False),
(83, "utf8", "utf8_bin", False),
(84, "big5", "big5_bin", False),
(85, "euckr", "euckr_bin", False),
(86, "gb2312", "gb2312_bin", False),
(87, "gbk", "gbk_bin", False),
(88, "sjis", "sjis_bin", False),
(89, "tis620", "tis620_bin", False),
(90, "ucs2", "ucs2_bin", False),
(91, "ujis", "ujis_bin", False),
(92, "geostd8", "geostd8_general_ci", True),
(93, "geostd8", "geostd8_bin", False),
(94, "latin1", "latin1_spanish_ci", False),
(95, "cp932", "cp932_japanese_ci", True),
(96, "cp932", "cp932_bin", False),
(97, "eucjpms", "eucjpms_japanese_ci", True),
(98, "eucjpms", "eucjpms_bin", False),
(99, "cp1250", "cp1250_polish_ci", False),
(101, "utf16", "utf16_unicode_ci", False),
(102, "utf16", "utf16_icelandic_ci", False),
(103, "utf16", "utf16_latvian_ci", False),
(104, "utf16", "utf16_romanian_ci", False),
(105, "utf16", "utf16_slovenian_ci", False),
(106, "utf16", "utf16_polish_ci", False),
(107, "utf16", "utf16_estonian_ci", False),
(108, "utf16", "utf16_spanish_ci", False),
(109, "utf16", "utf16_swedish_ci", False),
(110, "utf16", "utf16_turkish_ci", False),
(111, "utf16", "utf16_czech_ci", False),
(112, "utf16", "utf16_danish_ci", False),
(113, "utf16", "utf16_lithuanian_ci", False),
(114, "utf16", "utf16_slovak_ci", False),
(115, "utf16", "utf16_spanish2_ci", False),
(116, "utf16", "utf16_roman_ci", False),
(117, "utf16", "utf16_persian_ci", False),
(118, "utf16", "utf16_esperanto_ci", False),
(119, "utf16", "utf16_hungarian_ci", False),
(120, "utf16", "utf16_sinhala_ci", False),
(121, "utf16", "utf16_german2_ci", False),
(122, "utf16", "utf16_croatian_ci", False),
(123, "utf16", "utf16_unicode_520_ci", False),
(124, "utf16", "utf16_vietnamese_ci", False),
(128, "ucs2", "ucs2_unicode_ci", False),
(129, "ucs2", "ucs2_icelandic_ci", False),
(130, "ucs2", "ucs2_latvian_ci", False),
(131, "ucs2", "ucs2_romanian_ci", False),
(132, "ucs2", "ucs2_slovenian_ci", False),
(133, "ucs2", "ucs2_polish_ci", False),
(134, "ucs2", "ucs2_estonian_ci", False),
(135, "ucs2", "ucs2_spanish_ci", False),
(136, "ucs2", "ucs2_swedish_ci", False),
(137, "ucs2", "ucs2_turkish_ci", False),
(138, "ucs2", "ucs2_czech_ci", False),
(139, "ucs2", "ucs2_danish_ci", False),
(140, "ucs2", "ucs2_lithuanian_ci", False),
(141, "ucs2", "ucs2_slovak_ci", False),
(142, "ucs2", "ucs2_spanish2_ci", False),
(143, "ucs2", "ucs2_roman_ci", False),
(144, "ucs2", "ucs2_persian_ci", False),
(145, "ucs2", "ucs2_esperanto_ci", False),
(146, "ucs2", "ucs2_hungarian_ci", False),
(147, "ucs2", "ucs2_sinhala_ci", False),
(148, "ucs2", "ucs2_german2_ci", False),
(149, "ucs2", "ucs2_croatian_ci", False),
(150, "ucs2", "ucs2_unicode_520_ci", False),
(151, "ucs2", "ucs2_vietnamese_ci", False),
(159, "ucs2", "ucs2_general_mysql500_ci", False),
(160, "utf32", "utf32_unicode_ci", False),
(161, "utf32", "utf32_icelandic_ci", False),
(162, "utf32", "utf32_latvian_ci", False),
(163, "utf32", "utf32_romanian_ci", False),
(164, "utf32", "utf32_slovenian_ci", False),
(165, "utf32", "utf32_polish_ci", False),
(166, "utf32", "utf32_estonian_ci", False),
(167, "utf32", "utf32_spanish_ci", False),
(168, "utf32", "utf32_swedish_ci", False),
(169, "utf32", "utf32_turkish_ci", False),
(170, "utf32", "utf32_czech_ci", False),
(171, "utf32", "utf32_danish_ci", False),
(172, "utf32", "utf32_lithuanian_ci", False),
(173, "utf32", "utf32_slovak_ci", False),
(174, "utf32", "utf32_spanish2_ci", False),
(175, "utf32", "utf32_roman_ci", False),
(176, "utf32", "utf32_persian_ci", False),
(177, "utf32", "utf32_esperanto_ci", False),
(178, "utf32", "utf32_hungarian_ci", False),
(179, "utf32", "utf32_sinhala_ci", False),
(180, "utf32", "utf32_german2_ci", False),
(181, "utf32", "utf32_croatian_ci", False),
(182, "utf32", "utf32_unicode_520_ci", False),
(183, "utf32", "utf32_vietnamese_ci", False),
(192, "utf8", "utf8_unicode_ci", False),
(193, "utf8", "utf8_icelandic_ci", False),
(194, "utf8", "utf8_latvian_ci", False),
(195, "utf8", "utf8_romanian_ci", False),
(196, "utf8", "utf8_slovenian_ci", False),
(197, "utf8", "utf8_polish_ci", False),
(198, "utf8", "utf8_estonian_ci", False),
(199, "utf8", "utf8_spanish_ci", False),
(200, "utf8", "utf8_swedish_ci", False),
(201, "utf8", "utf8_turkish_ci", False),
(202, "utf8", "utf8_czech_ci", False),
(203, "utf8", "utf8_danish_ci", False),
(204, "utf8", "utf8_lithuanian_ci", False),
(205, "utf8", "utf8_slovak_ci", False),
(206, "utf8", "utf8_spanish2_ci", False),
(207, "utf8", "utf8_roman_ci", False),
(208, "utf8", "utf8_persian_ci", False),
(209, "utf8", "utf8_esperanto_ci", False),
(210, "utf8", "utf8_hungarian_ci", False),
(211, "utf8", "utf8_sinhala_ci", False),
(212, "utf8", "utf8_german2_ci", False),
(213, "utf8", "utf8_croatian_ci", False),
(214, "utf8", "utf8_unicode_520_ci", False),
(215, "utf8", "utf8_vietnamese_ci", False),
(223, "utf8", "utf8_general_mysql500_ci", False),
(224, "utf8mb4", "utf8mb4_unicode_ci", False),
(225, "utf8mb4", "utf8mb4_icelandic_ci", False),
(226, "utf8mb4", "utf8mb4_latvian_ci", False),
(227, "utf8mb4", "utf8mb4_romanian_ci", False),
(228, "utf8mb4", "utf8mb4_slovenian_ci", False),
(229, "utf8mb4", "utf8mb4_polish_ci", False),
(230, "utf8mb4", "utf8mb4_estonian_ci", False),
(231, "utf8mb4", "utf8mb4_spanish_ci", False),
(232, "utf8mb4", "utf8mb4_swedish_ci", False),
(233, "utf8mb4", "utf8mb4_turkish_ci", False),
(234, "utf8mb4", "utf8mb4_czech_ci", False),
(235, "utf8mb4", "utf8mb4_danish_ci", False),
(236, "utf8mb4", "utf8mb4_lithuanian_ci", False),
(237, "utf8mb4", "utf8mb4_slovak_ci", False),
(238, "utf8mb4", "utf8mb4_spanish2_ci", False),
(239, "utf8mb4", "utf8mb4_roman_ci", False),
(240, "utf8mb4", "utf8mb4_persian_ci", False),
(241, "utf8mb4", "utf8mb4_esperanto_ci", False),
(242, "utf8mb4", "utf8mb4_hungarian_ci", False),
(243, "utf8mb4", "utf8mb4_sinhala_ci", False),
(244, "utf8mb4", "utf8mb4_german2_ci", False),
(245, "utf8mb4", "utf8mb4_croatian_ci", False),
(246, "utf8mb4", "utf8mb4_unicode_520_ci", False),
(247, "utf8mb4", "utf8mb4_vietnamese_ci", False),
(248, "gb18030", "gb18030_chinese_ci", True),
(249, "gb18030", "gb18030_bin", False),
(250, "gb18030", "gb18030_unicode_520_ci", False),
)
charsets = Charsets()

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,33 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Setup of the `mysql.connector.aio` logger."""
import logging
logger = logging.getLogger("mysql.connector.aio")

View file

@ -0,0 +1,761 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
# pylint: disable=dangerous-default-value
"""Module implementing low-level socket communication with MySQL servers."""
__all__ = ["MySQLTcpSocket", "MySQLUnixSocket"]
import asyncio
import struct
import zlib
try:
import ssl
TLS_VERSIONS = {
"TLSv1": ssl.PROTOCOL_TLSv1,
"TLSv1.1": ssl.PROTOCOL_TLSv1_1,
"TLSv1.2": ssl.PROTOCOL_TLSv1_2,
"TLSv1.3": ssl.PROTOCOL_TLS,
}
except ImportError:
ssl = None
from abc import ABC, abstractmethod
from collections import deque
from typing import Any, Deque, List, Optional, Tuple
from ..errors import (
InterfaceError,
NotSupportedError,
OperationalError,
ProgrammingError,
ReadTimeoutError,
WriteTimeoutError,
)
from ..network import (
COMPRESSED_PACKET_HEADER_LENGTH,
MAX_PAYLOAD_LENGTH,
MIN_COMPRESS_LENGTH,
PACKET_HEADER_LENGTH,
)
from .utils import StreamWriter, open_connection
def _strioerror(err: IOError) -> str:
"""Reformat the IOError error message.
This function reformats the IOError error message.
"""
return str(err) if not err.errno else f"{err.errno} {err.strerror}"
class NetworkBroker(ABC):
"""Broker class interface.
The network object is a broker used as a delegate by a socket object. Whenever the
socket wants to deliver or get packets to or from the MySQL server it needs to rely
on its network broker (netbroker).
The netbroker sends `payloads` and receives `packets`.
A packet is a bytes sequence, it has a header and body (referred to as payload).
The first `PACKET_HEADER_LENGTH` or `COMPRESSED_PACKET_HEADER_LENGTH`
(as appropriate) bytes correspond to the `header`, the remaining ones represent the
`payload`.
The maximum payload length allowed to be sent per packet to the server is
`MAX_PAYLOAD_LENGTH`. When `send` is called with a payload whose length is greater
than `MAX_PAYLOAD_LENGTH` the netbroker breaks it down into packets, so the caller
of `send` can provide payloads of arbitrary length.
Finally, data received by the netbroker comes directly from the server, expect to
get a packet for each call to `recv`. The received packet contains a header and
payload, the latter respecting `MAX_PAYLOAD_LENGTH`.
"""
@abstractmethod
async def write(
self,
writer: StreamWriter,
address: str,
payload: bytes,
packet_number: Optional[int] = None,
compressed_packet_number: Optional[int] = None,
write_timeout: Optional[int] = None,
) -> None:
"""Send `payload` to the MySQL server.
If provided a payload whose length is greater than `MAX_PAYLOAD_LENGTH`, it is
broken down into packets.
Args:
sock: Object holding the socket connection.
address: Socket's location.
payload: Packet's body to send.
packet_number: Sequence id (packet ID) to attach to the header when sending
plain packets.
compressed_packet_number: Same as `packet_number` but used when sending
compressed packets.
write_timeout: Timeout in seconds before which sending a packet to the server
should finish else WriteTimeoutError is raised.
Raises:
:class:`OperationalError`: If something goes wrong while sending packets to
the MySQL server.
"""
@abstractmethod
async def read(
self,
reader: asyncio.StreamReader,
address: str,
read_timeout: Optional[int] = None,
) -> bytearray:
"""Get the next available packet from the MySQL server.
Args:
sock: Object holding the socket connection.
address: Socket's location.
read_timeout: Timeout in seconds before which reading a packet from the server
should finish.
Returns:
packet: A packet from the MySQL server.
Raises:
:class:`OperationalError`: If something goes wrong while receiving packets
from the MySQL server.
:class:`ReadTimeoutError`: If the time to receive a packet from the server takes
longer than `read_timeout`.
:class:`InterfaceError`: If something goes wrong while receiving packets
from the MySQL server.
"""
class NetworkBrokerPlain(NetworkBroker):
"""Broker class for MySQL socket communication."""
def __init__(self) -> None:
self._pktnr: int = -1 # packet number
@staticmethod
def get_header(pkt: bytes) -> Tuple[int, int]:
"""Recover the header information from a packet."""
if len(pkt) < PACKET_HEADER_LENGTH:
raise ValueError("Can't recover header info from an incomplete packet")
pll, seqid = (
struct.unpack("<I", pkt[0:3] + b"\x00")[0],
pkt[3],
)
# payload length, sequence id
return pll, seqid
def _set_next_pktnr(self, next_id: Optional[int] = None) -> None:
"""Set the given packet id, if any, else increment packet id."""
if next_id is None:
self._pktnr += 1
else:
self._pktnr = next_id
self._pktnr %= 256
async def _write_pkt(
self,
writer: StreamWriter,
address: str,
pkt: bytes,
) -> None:
"""Write packet to the comm channel."""
try:
writer.write(pkt)
await writer.drain()
except IOError as err:
raise OperationalError(
errno=2055, values=(address, _strioerror(err))
) from err
except AttributeError as err:
raise OperationalError(errno=2006) from err
async def _read_chunk(
self,
reader: asyncio.StreamReader,
size: int = 0,
read_timeout: Optional[int] = None,
) -> bytearray:
"""Read `size` bytes from the comm channel."""
try:
pkt = bytearray(b"")
while len(pkt) < size:
chunk = await asyncio.wait_for(
reader.read(size - len(pkt)), read_timeout
)
if not chunk:
raise InterfaceError(errno=2013)
pkt += chunk
return pkt
except (asyncio.CancelledError, asyncio.TimeoutError) as err:
raise ReadTimeoutError(errno=3024) from err
async def write(
self,
writer: StreamWriter,
address: str,
payload: bytes,
packet_number: Optional[int] = None,
compressed_packet_number: Optional[int] = None,
write_timeout: Optional[int] = None,
) -> None:
"""Send payload to the MySQL server.
If provided a payload whose length is greater than `MAX_PAYLOAD_LENGTH`, it is
broken down into packets.
"""
self._set_next_pktnr(packet_number)
# If the payload is larger than or equal to MAX_PAYLOAD_LENGTH the length is
# set to 2^24 - 1 (ff ff ff) and additional packets are sent with the rest of
# the payload until the payload of a packet is less than MAX_PAYLOAD_LENGTH.
offset = 0
try:
for _ in range(len(payload) // MAX_PAYLOAD_LENGTH):
# payload_len, sequence_id, payload
await asyncio.wait_for(
self._write_pkt(
writer,
address,
b"\xff" * 3
+ struct.pack("<B", self._pktnr)
+ payload[offset : offset + MAX_PAYLOAD_LENGTH],
),
write_timeout,
)
self._set_next_pktnr()
offset += MAX_PAYLOAD_LENGTH
await asyncio.wait_for(
self._write_pkt(
writer,
address,
struct.pack("<I", len(payload) - offset)[0:3]
+ struct.pack("<B", self._pktnr)
+ payload[offset:],
),
write_timeout,
)
except (asyncio.CancelledError, asyncio.TimeoutError) as err:
raise WriteTimeoutError(errno=3024) from err
async def read(
self,
reader: asyncio.StreamReader,
address: str,
read_timeout: Optional[int] = None,
) -> bytearray:
"""Receive `one` packet from the MySQL server."""
try:
# Read the header of the MySQL packet.
header = await self._read_chunk(reader, PACKET_HEADER_LENGTH, read_timeout)
# Pull the payload length and sequence id.
payload_len, self._pktnr = self.get_header(header)
# Read the payload, and return packet.
return header + await self._read_chunk(reader, payload_len, read_timeout)
except IOError as err:
raise OperationalError(
errno=2055, values=(address, _strioerror(err))
) from err
class NetworkBrokerCompressed(NetworkBrokerPlain):
"""Broker class for MySQL socket communication."""
def __init__(self) -> None:
super().__init__()
self._compressed_pktnr = -1
self._queue_read: Deque[bytearray] = deque()
@staticmethod
def _prepare_packets(payload: bytes, pktnr: int) -> List[bytes]:
"""Prepare a payload for sending to the MySQL server."""
offset = 0
pkts = []
# If the payload is larger than or equal to MAX_PAYLOAD_LENGTH the length is
# set to 2^24 - 1 (ff ff ff) and additional packets are sent with the rest of
# the payload until the payload of a packet is less than MAX_PAYLOAD_LENGTH.
for _ in range(len(payload) // MAX_PAYLOAD_LENGTH):
# payload length + sequence id + payload
pkts.append(
b"\xff" * 3
+ struct.pack("<B", pktnr)
+ payload[offset : offset + MAX_PAYLOAD_LENGTH]
)
pktnr = (pktnr + 1) % 256
offset += MAX_PAYLOAD_LENGTH
pkts.append(
struct.pack("<I", len(payload) - offset)[0:3]
+ struct.pack("<B", pktnr)
+ payload[offset:]
)
return pkts
@staticmethod
def get_header(pkt: bytes) -> Tuple[int, int, int]: # type: ignore[override]
"""Recover the header information from a packet."""
if len(pkt) < COMPRESSED_PACKET_HEADER_LENGTH:
raise ValueError("Can't recover header info from an incomplete packet")
compressed_pll, seqid, uncompressed_pll = (
struct.unpack("<I", pkt[0:3] + b"\x00")[0],
pkt[3],
struct.unpack("<I", pkt[4:7] + b"\x00")[0],
)
# compressed payload length, sequence id, uncompressed payload length
return compressed_pll, seqid, uncompressed_pll
def _set_next_compressed_pktnr(self, next_id: Optional[int] = None) -> None:
"""Set the given packet id, if any, else increment packet id."""
if next_id is None:
self._compressed_pktnr += 1
else:
self._compressed_pktnr = next_id
self._compressed_pktnr %= 256
async def _write_pkt(
self,
writer: StreamWriter,
address: str,
pkt: bytes,
) -> None:
"""Compress packet and write it to the comm channel."""
compressed_pkt = zlib.compress(pkt)
pkt = (
struct.pack("<I", len(compressed_pkt))[0:3]
+ struct.pack("<B", self._compressed_pktnr)
+ struct.pack("<I", len(pkt))[0:3]
+ compressed_pkt
)
return await super()._write_pkt(writer, address, pkt)
async def write(
self,
writer: StreamWriter,
address: str,
payload: bytes,
packet_number: Optional[int] = None,
compressed_packet_number: Optional[int] = None,
write_timeout: Optional[int] = None,
) -> None:
"""Send `payload` as compressed packets to the MySQL server.
If provided a payload whose length is greater than `MAX_PAYLOAD_LENGTH`, it is
broken down into packets.
"""
# Get next packet numbers.
self._set_next_pktnr(packet_number)
self._set_next_compressed_pktnr(compressed_packet_number)
try:
payload_prep = bytearray(b"").join(
self._prepare_packets(payload, self._pktnr)
)
if len(payload) >= MAX_PAYLOAD_LENGTH - PACKET_HEADER_LENGTH:
# Sending a MySQL payload of the size greater or equal to 2^24 - 5 via
# compression leads to at least one extra compressed packet WHY? let's say
# len(payload) is MAX_PAYLOAD_LENGTH - 3; when preparing the payload, a
# header of size PACKET_HEADER_LENGTH is pre-appended to the payload.
# This means that len(payload_prep) is
# MAX_PAYLOAD_LENGTH - 3 + PACKET_HEADER_LENGTH = MAX_PAYLOAD_LENGTH + 1
# surpassing the maximum allowed payload size per packet.
offset = 0
# Send several MySQL packets.
for _ in range(len(payload_prep) // MAX_PAYLOAD_LENGTH):
await asyncio.wait_for(
self._write_pkt(
writer,
address,
payload_prep[offset : offset + MAX_PAYLOAD_LENGTH],
),
write_timeout,
)
self._set_next_compressed_pktnr()
offset += MAX_PAYLOAD_LENGTH
await asyncio.wait_for(
self._write_pkt(writer, address, payload_prep[offset:]),
write_timeout,
)
else:
# Send one MySQL packet.
# For small packets it may be too costly to compress the packet.
# Usually payloads less than 50 bytes (MIN_COMPRESS_LENGTH) aren't
# compressed (see MySQL source code Documentation).
if len(payload) > MIN_COMPRESS_LENGTH:
# Perform compression.
await asyncio.wait_for(
self._write_pkt(writer, address, payload_prep), write_timeout
)
else:
# Skip compression.
await asyncio.wait_for(
super()._write_pkt(
writer,
address,
struct.pack("<I", len(payload_prep))[0:3]
+ struct.pack("<B", self._compressed_pktnr)
+ struct.pack("<I", 0)[0:3]
+ payload_prep,
),
write_timeout,
)
except (asyncio.CancelledError, asyncio.TimeoutError) as err:
raise WriteTimeoutError(errno=3024) from err
async def _read_compressed_pkt(
self,
reader: asyncio.StreamReader,
compressed_pll: int,
read_timeout: Optional[int] = None,
) -> None:
"""Handle reading of a compressed packet."""
# compressed_pll stands for compressed payload length.
pkt = bytearray(
zlib.decompress(
await super()._read_chunk(reader, compressed_pll, read_timeout)
)
)
offset = 0
while offset < len(pkt):
# pll stands for payload length
pll = struct.unpack(
"<I", pkt[offset : offset + PACKET_HEADER_LENGTH - 1] + b"\x00"
)[0]
if PACKET_HEADER_LENGTH + pll > len(pkt) - offset:
# More bytes need to be consumed.
# Read the header of the next MySQL packet.
header = await super()._read_chunk(
reader, COMPRESSED_PACKET_HEADER_LENGTH, read_timeout
)
# compressed payload length, sequence id, uncompressed payload length.
(
compressed_pll,
self._compressed_pktnr,
uncompressed_pll,
) = self.get_header(header)
compressed_pkt = await super()._read_chunk(
reader, compressed_pll, read_timeout
)
# Recalling that if uncompressed payload length == 0, the packet comes
# in uncompressed, so no decompression is needed.
pkt += (
compressed_pkt
if uncompressed_pll == 0
else zlib.decompress(compressed_pkt)
)
self._queue_read.append(pkt[offset : offset + PACKET_HEADER_LENGTH + pll])
offset += PACKET_HEADER_LENGTH + pll
async def read(
self,
reader: asyncio.StreamReader,
address: str,
read_timeout: Optional[int] = None,
) -> bytearray:
"""Receive `one` or `several` packets from the MySQL server, enqueue them, and
return the packet at the head.
"""
if not self._queue_read:
try:
# Read the header of the next MySQL packet.
header = await super()._read_chunk(
reader, COMPRESSED_PACKET_HEADER_LENGTH, read_timeout
)
# compressed payload length, sequence id, uncompressed payload length
(
compressed_pll,
self._compressed_pktnr,
uncompressed_pll,
) = self.get_header(header)
if uncompressed_pll == 0:
# Packet is not compressed, so just store it.
self._queue_read.append(
await super()._read_chunk(reader, compressed_pll, read_timeout)
)
else:
# Packet comes in compressed, further action is needed.
await self._read_compressed_pkt(
reader, compressed_pll, read_timeout
)
except IOError as err:
raise OperationalError(
errno=2055, values=(address, _strioerror(err))
) from err
if not self._queue_read:
return None
pkt = self._queue_read.popleft()
self._pktnr = pkt[3]
return pkt
class MySQLSocket(ABC):
"""MySQL socket communication interface.
Examples:
Subclasses: network.MySQLTCPSocket and network.MySQLUnixSocket.
"""
def __init__(self) -> None:
"""Network layer where transactions are made with plain (uncompressed) packets
is enabled by default.
"""
self._reader: Optional[asyncio.StreamReader] = None
self._writer: Optional[StreamWriter] = None
self._connection_timeout: Optional[int] = None
self._address: Optional[str] = None
self._netbroker: NetworkBroker = NetworkBrokerPlain()
self._is_connected: bool = False
@property
def address(self) -> str:
"""Socket location."""
return self._address
@abstractmethod
async def open_connection(self, **kwargs: Any) -> None:
"""Open the socket."""
async def close_connection(self) -> None:
"""Close the connection."""
if self._writer:
try:
self._writer.close()
# Without transport.abort(), an error is raised when using SSL
if self._writer.transport is not None:
self._writer.transport.abort()
await self._writer.wait_closed()
except Exception as _: # pylint: disable=broad-exception-caught)
# we can ignore issues like ConnectionRefused or ConnectionAborted
# as these instances might popup if the connection was closed due to timeout issues
pass
self._is_connected = False
def is_connected(self) -> bool:
"""Check if the socket is connected.
Return:
bool: Returns `True` if the socket is connected to MySQL server.
"""
return self._is_connected
def set_connection_timeout(self, timeout: int) -> None:
"""Set the connection timeout."""
self._connection_timeout = timeout
def switch_to_compressed_mode(self) -> None:
"""Enable network layer where transactions are made with compressed packets."""
self._netbroker = NetworkBrokerCompressed()
async def switch_to_ssl(self, ssl_context: ssl.SSLContext) -> None:
"""Upgrade an existing stream-based connection to TLS.
The `start_tls()` method from `asyncio.streams.StreamWriter` is only available
in Python 3.11. This method is used as a workaround.
The MySQL TLS negotiation happens in the middle of the TCP connection.
Therefore, passing a socket to open connection will cause it to negotiate
TLS on an existing connection.
Args:
ssl_context: The SSL Context to be used.
Raises:
RuntimeError: If the transport does not expose the socket instance.
"""
# Ensure that self._writer is already created
assert self._writer is not None
socket = self._writer.transport.get_extra_info("socket")
if socket.family == 1: # socket.AF_UNIX
raise ProgrammingError("SSL is not supported when using Unix sockets")
await self._writer.start_tls(ssl_context)
async def write(
self,
payload: bytes,
packet_number: Optional[int] = None,
compressed_packet_number: Optional[int] = None,
write_timeout: Optional[int] = None,
) -> None:
"""Send packets to the MySQL server."""
await self._netbroker.write(
self._writer,
self.address,
payload,
packet_number=packet_number,
compressed_packet_number=compressed_packet_number,
write_timeout=write_timeout,
)
async def read(self, read_timeout: Optional[int] = None) -> bytearray:
"""Read packets from the MySQL server."""
return await self._netbroker.read(self._reader, self.address, read_timeout)
def build_ssl_context(
self,
ssl_ca: Optional[str] = None,
ssl_cert: Optional[str] = None,
ssl_key: Optional[str] = None,
ssl_verify_cert: Optional[bool] = False,
ssl_verify_identity: Optional[bool] = False,
tls_versions: Optional[List[str]] = [],
tls_cipher_suites: Optional[List[str]] = [],
) -> ssl.SSLContext:
"""Build a SSLContext."""
tls_version: Optional[str] = None
if not self._reader:
raise InterfaceError(errno=2048)
if ssl is None:
raise RuntimeError("Python installation has no SSL support")
try:
if tls_versions:
tls_versions.sort(reverse=True)
tls_version = tls_versions[0]
ssl_protocol = TLS_VERSIONS[tls_version]
context = ssl.SSLContext(ssl_protocol)
if tls_version == "TLSv1.3":
if "TLSv1.2" not in tls_versions:
context.options |= ssl.OP_NO_TLSv1_2
if "TLSv1.1" not in tls_versions:
context.options |= ssl.OP_NO_TLSv1_1
if "TLSv1" not in tls_versions:
context.options |= ssl.OP_NO_TLSv1
else:
context = ssl.create_default_context()
context.check_hostname = ssl_verify_identity
if ssl_verify_cert:
context.verify_mode = ssl.CERT_REQUIRED
elif ssl_verify_identity:
context.verify_mode = ssl.CERT_OPTIONAL
else:
context.verify_mode = ssl.CERT_NONE
context.load_default_certs()
if ssl_ca:
try:
context.load_verify_locations(ssl_ca)
except (IOError, ssl.SSLError) as err:
raise InterfaceError(f"Invalid CA Certificate: {err}") from err
if ssl_cert:
try:
context.load_cert_chain(ssl_cert, ssl_key)
except (IOError, ssl.SSLError) as err:
raise InterfaceError(f"Invalid Certificate/Key: {err}") from err
# TLSv1.3 ciphers cannot be disabled with `SSLContext.set_ciphers(...)`,
# see https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers.
if tls_cipher_suites and tls_version == "TLSv1.2":
context.set_ciphers(":".join(tls_cipher_suites))
return context
except NameError as err:
raise NotSupportedError("Python installation has no SSL support") from err
except (
IOError,
NotImplementedError,
ssl.CertificateError,
ssl.SSLError,
) as err:
raise InterfaceError(str(err)) from err
class MySQLTcpSocket(MySQLSocket):
"""MySQL socket class using TCP/IP.
Args:
host: MySQL host name.
port: MySQL port.
force_ipv6: Force IPv6 usage.
"""
def __init__(
self, host: str = "127.0.0.1", port: int = 3306, force_ipv6: bool = False
):
super().__init__()
self._host: str = host
self._port: int = port
self._force_ipv6: bool = force_ipv6
self._address: str = f"{host}:{port}"
async def open_connection(self, **kwargs: Any) -> None:
"""Open TCP/IP connection."""
self._reader, self._writer = await open_connection(
host=self._host, port=self._port, **kwargs
)
self._is_connected = True
class MySQLUnixSocket(MySQLSocket):
"""MySQL socket class using UNIX sockets.
Args:
unix_socket: UNIX socket file path.
"""
def __init__(self, unix_socket: str = "/tmp/mysql.sock"):
super().__init__()
self._address: str = unix_socket
async def open_connection(self, **kwargs: Any) -> None:
"""Open UNIX socket connection."""
(
self._reader,
self._writer,
) = await asyncio.open_unix_connection( # type: ignore[assignment]
path=self._address, **kwargs
)
self._is_connected = True

View file

@ -0,0 +1,162 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Base Authentication Plugin class."""
__all__ = ["MySQLAuthPlugin", "get_auth_plugin"]
import importlib
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Optional, Type
from mysql.connector.errors import NotSupportedError, ProgrammingError
from mysql.connector.logger import logger
if TYPE_CHECKING:
from ..network import MySQLSocket
DEFAULT_PLUGINS_PKG = "mysql.connector.aio.plugins"
class MySQLAuthPlugin(ABC):
"""Authorization plugin interface."""
def __init__(
self,
username: str,
password: str,
ssl_enabled: bool = False,
) -> None:
"""Constructor."""
self._username: str = "" if username is None else username
self._password: str = "" if password is None else password
self._ssl_enabled: bool = ssl_enabled
@property
def ssl_enabled(self) -> bool:
"""Signals whether or not SSL is enabled."""
return self._ssl_enabled
@property
@abstractmethod
def requires_ssl(self) -> bool:
"""Signals whether or not SSL is required."""
@property
@abstractmethod
def name(self) -> str:
"""Plugin official name."""
@abstractmethod
def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
"""Make the client's authorization response.
Args:
auth_data: Authorization data.
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Client's authorization response.
"""
async def auth_more_response(
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
) -> bytes:
"""Handles server's `auth more data` response.
Args:
sock: Pointer to the socket connection.
auth_data: Authentication method data (from a packet representing
an `auth more data` response).
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Last server's response after back-and-forth communication.
"""
raise NotImplementedError
@abstractmethod
async def auth_switch_response(
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
) -> bytes:
"""Handles server's `auth switch request` response.
Args:
sock: Pointer to the socket connection.
auth_data: Plugin provided data (extracted from a packet
representing an `auth switch request` response).
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Last server's response after back-and-forth communication.
"""
@lru_cache(maxsize=10, typed=False)
def get_auth_plugin(
plugin_name: str,
auth_plugin_class: Optional[str] = None,
) -> Type[MySQLAuthPlugin]:
"""Return authentication class based on plugin name
This function returns the class for the authentication plugin plugin_name.
The returned class is a subclass of BaseAuthPlugin.
Args:
plugin_name (str): Authentication plugin name.
auth_plugin_class (str): Authentication plugin class name.
Raises:
NotSupportedError: When plugin_name is not supported.
Returns:
Subclass of `MySQLAuthPlugin`.
"""
package = DEFAULT_PLUGINS_PKG
if plugin_name:
try:
logger.info("package: %s", package)
logger.info("plugin_name: %s", plugin_name)
plugin_module = importlib.import_module(f".{plugin_name}", package)
if not auth_plugin_class or not hasattr(plugin_module, auth_plugin_class):
auth_plugin_class = plugin_module.AUTHENTICATION_PLUGIN_CLASS
logger.info("AUTHENTICATION_PLUGIN_CLASS: %s", auth_plugin_class)
return getattr(plugin_module, auth_plugin_class)
except ModuleNotFoundError as err:
logger.warning("Requested Module was not found: %s", err)
except ValueError as err:
raise ProgrammingError(f"Invalid module name: {err}") from err
raise NotSupportedError(f"Authentication plugin '{plugin_name}' is not supported")

View file

@ -0,0 +1,577 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
# mypy: disable-error-code="str-bytes-safe,misc"
"""Kerberos Authentication Plugin."""
import getpass
import os
import struct
from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Tuple
from mysql.connector.errors import InterfaceError, ProgrammingError
from mysql.connector.logger import logger
from ..authentication import ERR_STATUS
if TYPE_CHECKING:
from ..network import MySQLSocket
try:
import gssapi
except ImportError:
gssapi = None
if os.name != "nt":
raise ProgrammingError(
"Module gssapi is required for GSSAPI authentication "
"mechanism but was not found. Unable to authenticate "
"with the server"
) from None
try:
import sspi
import sspicon
except ImportError:
sspi = None
sspicon = None
from . import MySQLAuthPlugin
AUTHENTICATION_PLUGIN_CLASS = (
"MySQLSSPIKerberosAuthPlugin" if os.name == "nt" else "MySQLKerberosAuthPlugin"
)
class MySQLBaseKerberosAuthPlugin(MySQLAuthPlugin):
"""Base class for the MySQL Kerberos authentication plugin."""
@property
def name(self) -> str:
"""Plugin official name."""
return "authentication_kerberos_client"
@property
def requires_ssl(self) -> bool:
"""Signals whether or not SSL is required."""
return False
@abstractmethod
def auth_continue(
self, tgt_auth_challenge: Optional[bytes]
) -> Tuple[Optional[bytes], bool]:
"""Continue with the Kerberos TGT service request.
With the TGT authentication service given response generate a TGT
service request. This method must be invoked sequentially (in a loop)
until the security context is completed and an empty response needs to
be send to acknowledge the server.
Args:
tgt_auth_challenge: the challenge for the negotiation.
Returns:
tuple (bytearray TGS service request,
bool True if context is completed otherwise False).
"""
async def auth_switch_response(
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
) -> bytes:
"""Handles server's `auth switch request` response.
Args:
sock: Pointer to the socket connection.
auth_data: Plugin provided data (extracted from a packet
representing an `auth switch request` response).
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Last server's response after back-and-forth
communication.
"""
logger.debug("# auth_data: %s", auth_data)
response = self.auth_response(auth_data, ignore_auth_data=False, **kwargs)
if response is None:
raise InterfaceError("Got a NULL auth response")
logger.debug("# request: %s size: %s", response, len(response))
await sock.write(response)
packet = await sock.read()
logger.debug("# server response packet: %s", packet)
if packet != ERR_STATUS:
rcode_size = 5 # Reader size for the response status code
logger.debug("# Continue with GSSAPI authentication")
logger.debug("# Response header: %s", packet[: rcode_size + 1])
logger.debug("# Response size: %s", len(packet))
logger.debug("# Negotiate a service request")
complete = False
tries = 0
while not complete and tries < 5:
logger.debug("%s Attempt %s %s", "-" * 20, tries + 1, "-" * 20)
logger.debug("<< Server response: %s", packet)
logger.debug("# Response code: %s", packet[: rcode_size + 1])
token, complete = self.auth_continue(packet[rcode_size:])
if token:
await sock.write(token)
if complete:
break
packet = await sock.read()
logger.debug(">> Response to server: %s", token)
tries += 1
if not complete:
raise InterfaceError(
f"Unable to fulfill server request after {tries} "
f"attempts. Last server response: {packet}"
)
logger.debug(
"Last response from server: %s length: %d",
packet,
len(packet),
)
# Receive OK packet from server.
packet = await sock.read()
logger.debug("<< Ok packet from server: %s", packet)
return bytes(packet)
# pylint: disable=c-extension-no-member,no-member
class MySQLKerberosAuthPlugin(MySQLBaseKerberosAuthPlugin):
"""Implement the MySQL Kerberos authentication plugin."""
context: Optional[gssapi.SecurityContext] = None
@staticmethod
def get_user_from_credentials() -> str:
"""Get user from credentials without realm."""
try:
creds = gssapi.Credentials(usage="initiate")
user = str(creds.name)
if user.find("@") != -1:
user, _ = user.split("@", 1)
return user
except gssapi.raw.misc.GSSError:
return getpass.getuser()
@staticmethod
def get_store() -> dict:
"""Get a credentials store dictionary.
Returns:
dict: Credentials store dictionary with the krb5 ccache name.
Raises:
InterfaceError: If 'KRB5CCNAME' environment variable is empty.
"""
krb5ccname = os.environ.get(
"KRB5CCNAME",
(
f"/tmp/krb5cc_{os.getuid()}"
if os.name == "posix"
else Path("%TEMP%").joinpath("krb5cc")
),
)
if not krb5ccname:
raise InterfaceError(
"The 'KRB5CCNAME' environment variable is set to empty"
)
logger.debug("Using krb5 ccache name: FILE:%s", krb5ccname)
store = {b"ccache": f"FILE:{krb5ccname}".encode("utf-8")}
return store
def _acquire_cred_with_password(self, upn: str) -> gssapi.raw.creds.Creds:
"""Acquire and store credentials through provided password.
Args:
upn (str): User Principal Name.
Returns:
gssapi.raw.creds.Creds: GSSAPI credentials.
"""
logger.debug("Attempt to acquire credentials through provided password")
user = gssapi.Name(upn, gssapi.NameType.user)
password = self._password.encode("utf-8")
try:
acquire_cred_result = gssapi.raw.acquire_cred_with_password(
user, password, usage="initiate"
)
creds = acquire_cred_result.creds
gssapi.raw.store_cred_into(
self.get_store(),
creds=creds,
mech=gssapi.MechType.kerberos,
overwrite=True,
set_default=True,
)
except gssapi.raw.misc.GSSError as err:
raise ProgrammingError(
f"Unable to acquire credentials with the given password: {err}"
) from err
return creds
@staticmethod
def _parse_auth_data(packet: bytes) -> Tuple[str, str]:
"""Parse authentication data.
Get the SPN and REALM from the authentication data packet.
Format:
SPN string length two bytes <B1> <B2> +
SPN string +
UPN realm string length two bytes <B1> <B2> +
UPN realm string
Returns:
tuple: With 'spn' and 'realm'.
"""
spn_len = struct.unpack("<H", packet[:2])[0]
packet = packet[2:]
spn = struct.unpack(f"<{spn_len}s", packet[:spn_len])[0]
packet = packet[spn_len:]
realm_len = struct.unpack("<H", packet[:2])[0]
realm = struct.unpack(f"<{realm_len}s", packet[2:])[0]
return spn.decode(), realm.decode()
def auth_response(
self, auth_data: Optional[bytes] = None, **kwargs: Any
) -> Optional[bytes]:
"""Prepare the first message to the server."""
spn = None
realm = None
if auth_data and not kwargs.get("ignore_auth_data", True):
try:
spn, realm = self._parse_auth_data(auth_data)
except struct.error as err:
raise InterruptedError(f"Invalid authentication data: {err}") from err
if spn is None:
return self._password.encode() + b"\x00"
upn = f"{self._username}@{realm}" if self._username else None
logger.debug("Service Principal: %s", spn)
logger.debug("Realm: %s", realm)
try:
# Attempt to retrieve credentials from cache file
creds: Any = gssapi.Credentials(usage="initiate")
creds_upn = str(creds.name)
logger.debug("Cached credentials found")
logger.debug("Cached credentials UPN: %s", creds_upn)
# Remove the realm from user
if creds_upn.find("@") != -1:
creds_user, creds_realm = creds_upn.split("@", 1)
else:
creds_user = creds_upn
creds_realm = None
upn = f"{self._username}@{realm}" if self._username else creds_upn
# The user from cached credentials matches with the given user?
if self._username and self._username != creds_user:
logger.debug(
"The user from cached credentials doesn't match with the "
"given user"
)
if self._password is not None:
creds = self._acquire_cred_with_password(upn)
if creds_realm and creds_realm != realm and self._password is not None:
creds = self._acquire_cred_with_password(upn)
except gssapi.raw.exceptions.ExpiredCredentialsError as err:
if upn and self._password is not None:
creds = self._acquire_cred_with_password(upn)
else:
raise InterfaceError(f"Credentials has expired: {err}") from err
except gssapi.raw.misc.GSSError as err:
if upn and self._password is not None:
creds = self._acquire_cred_with_password(upn)
else:
raise InterfaceError(
f"Unable to retrieve cached credentials error: {err}"
) from err
flags = (
gssapi.RequirementFlag.mutual_authentication,
gssapi.RequirementFlag.extended_error,
gssapi.RequirementFlag.delegate_to_peer,
)
name = gssapi.Name(spn, name_type=gssapi.NameType.kerberos_principal)
cname = name.canonicalize(gssapi.MechType.kerberos)
self.context = gssapi.SecurityContext(
name=cname, creds=creds, flags=sum(flags), usage="initiate"
)
try:
initial_client_token: Optional[bytes] = self.context.step()
except gssapi.raw.misc.GSSError as err:
raise InterfaceError(f"Unable to initiate security context: {err}") from err
logger.debug("Initial client token: %s", initial_client_token)
return initial_client_token
def auth_continue(
self, tgt_auth_challenge: Optional[bytes]
) -> Tuple[Optional[bytes], bool]:
"""Continue with the Kerberos TGT service request.
With the TGT authentication service given response generate a TGT
service request. This method must be invoked sequentially (in a loop)
until the security context is completed and an empty response needs to
be send to acknowledge the server.
Args:
tgt_auth_challenge: the challenge for the negotiation.
Returns:
tuple (bytearray TGS service request,
bool True if context is completed otherwise False).
"""
logger.debug("tgt_auth challenge: %s", tgt_auth_challenge)
resp: Optional[bytes] = self.context.step(tgt_auth_challenge)
logger.debug("Context step response: %s", resp)
logger.debug("Context completed?: %s", self.context.complete)
return resp, self.context.complete
def auth_accept_close_handshake(self, message: bytes) -> bytes:
"""Accept handshake and generate closing handshake message for server.
This method verifies the server authenticity from the given message
and included signature and generates the closing handshake for the
server.
When this method is invoked the security context is already established
and the client and server can send GSSAPI formated secure messages.
To finish the authentication handshake the server sends a message
with the security layer availability and the maximum buffer size.
Since the connector only uses the GSSAPI authentication mechanism to
authenticate the user with the server, the server will verify clients
message signature and terminate the GSSAPI authentication and send two
messages; an authentication acceptance b'\x01\x00\x00\x08\x01' and a
OK packet (that must be received after sent the returned message from
this method).
Args:
message: a wrapped gssapi message from the server.
Returns:
bytearray (closing handshake message to be send to the server).
"""
if not self.context.complete:
raise ProgrammingError("Security context is not completed")
logger.debug("Server message: %s", message)
logger.debug("GSSAPI flags in use: %s", self.context.actual_flags)
try:
unwraped = self.context.unwrap(message)
logger.debug("Unwraped: %s", unwraped)
except gssapi.raw.exceptions.BadMICError as err:
logger.debug("Unable to unwrap server message: %s", err)
raise InterfaceError(f"Unable to unwrap server message: {err}") from err
logger.debug("Unwrapped server message: %s", unwraped)
# The message contents for the clients closing message:
# - security level 1 byte, must be always 1.
# - conciliated buffer size 3 bytes, without importance as no
# further GSSAPI messages will be sends.
response = bytearray(b"\x01\x00\x00\00")
# Closing handshake must not be encrypted.
logger.debug("Message response: %s", response)
wraped = self.context.wrap(response, encrypt=False)
logger.debug(
"Wrapped message response: %s, length: %d",
wraped[0],
len(wraped[0]),
)
return wraped.message
class MySQLSSPIKerberosAuthPlugin(MySQLBaseKerberosAuthPlugin):
"""Implement the MySQL Kerberos authentication plugin with Windows SSPI"""
context: Any = None
clientauth: Any = None
@staticmethod
def _parse_auth_data(packet: bytes) -> Tuple[str, str]:
"""Parse authentication data.
Get the SPN and REALM from the authentication data packet.
Format:
SPN string length two bytes <B1> <B2> +
SPN string +
UPN realm string length two bytes <B1> <B2> +
UPN realm string
Returns:
tuple: With 'spn' and 'realm'.
"""
spn_len = struct.unpack("<H", packet[:2])[0]
packet = packet[2:]
spn = struct.unpack(f"<{spn_len}s", packet[:spn_len])[0]
packet = packet[spn_len:]
realm_len = struct.unpack("<H", packet[:2])[0]
realm = struct.unpack(f"<{realm_len}s", packet[2:])[0]
return spn.decode(), realm.decode()
def auth_response(
self, auth_data: Optional[bytes] = None, **kwargs: Any
) -> Optional[bytes]:
"""Prepare the first message to the server.
Args:
kwargs:
ignore_auth_data (bool): if True, the provided auth data is ignored.
"""
logger.debug("auth_response for sspi")
spn = None
realm = None
if auth_data and not kwargs.get("ignore_auth_data", True):
try:
spn, realm = self._parse_auth_data(auth_data)
except struct.error as err:
raise InterruptedError(f"Invalid authentication data: {err}") from err
logger.debug("Service Principal: %s", spn)
logger.debug("Realm: %s", realm)
if sspicon is None or sspi is None:
raise ProgrammingError(
'Package "pywin32" (Python for Win32 (pywin32) extensions)'
" is not installed."
)
flags = (sspicon.ISC_REQ_MUTUAL_AUTH, sspicon.ISC_REQ_DELEGATE)
if self._username and self._password:
_auth_info = (self._username, realm, self._password)
else:
_auth_info = None
targetspn = spn
logger.debug("targetspn: %s", targetspn)
logger.debug("_auth_info is None: %s", _auth_info is None)
# The Security Support Provider Interface (SSPI) is an interface
# that allows us to choose from a set of SSPs available in the
# system; the idea of SSPI is to keep interface consistent no
# matter what back end (a.k.a., SSP) we choose.
# When using SSPI we should not use Kerberos directly as SSP,
# as remarked in [2], but we can use it indirectly via another
# SSP named Negotiate that acts as an application layer between
# SSPI and the other SSPs [1].
# Negotiate can select between Kerberos and NTLM on the fly;
# it chooses Kerberos unless it cannot be used by one of the
# systems involved in the authentication or the calling
# application did not provide sufficient information to use
# Kerberos.
# prefix: https://docs.microsoft.com/en-us/windows/win32/secauthn
# [1] prefix/microsoft-negotiate?source=recommendations
# [2] prefix/microsoft-kerberos?source=recommendations
self.clientauth = sspi.ClientAuth(
"Negotiate",
targetspn=targetspn,
auth_info=_auth_info,
scflags=sum(flags),
datarep=sspicon.SECURITY_NETWORK_DREP,
)
try:
data = None
err, out_buf = self.clientauth.authorize(data)
logger.debug("Context step err: %s", err)
logger.debug("Context step out_buf: %s", out_buf)
logger.debug("Context completed?: %s", self.clientauth.authenticated)
initial_client_token = out_buf[0].Buffer
logger.debug("pkg_info: %s", self.clientauth.pkg_info)
except Exception as err:
raise InterfaceError(f"Unable to initiate security context: {err}") from err
logger.debug("Initial client token: %s", initial_client_token)
return initial_client_token
def auth_continue(
self, tgt_auth_challenge: Optional[bytes]
) -> Tuple[Optional[bytes], bool]:
"""Continue with the Kerberos TGT service request.
With the TGT authentication service given response generate a TGT
service request. This method must be invoked sequentially (in a loop)
until the security context is completed and an empty response needs to
be send to acknowledge the server.
Args:
tgt_auth_challenge: the challenge for the negotiation.
Returns:
tuple (bytearray TGS service request,
bool True if context is completed otherwise False).
"""
logger.debug("tgt_auth challenge: %s", tgt_auth_challenge)
err, out_buf = self.clientauth.authorize(tgt_auth_challenge)
logger.debug("Context step err: %s", err)
logger.debug("Context step out_buf: %s", out_buf)
resp = out_buf[0].Buffer
logger.debug("Context step resp: %s", resp)
logger.debug("Context completed?: %s", self.clientauth.authenticated)
return resp, self.clientauth.authenticated

View file

@ -0,0 +1,595 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""LDAP SASL Authentication Plugin."""
import hmac
from base64 import b64decode, b64encode
from hashlib import sha1, sha256
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from uuid import uuid4
from mysql.connector.authentication import ERR_STATUS
from mysql.connector.errors import InterfaceError, ProgrammingError
from mysql.connector.logger import logger
from mysql.connector.types import StrOrBytes
from mysql.connector.utils import (
normalize_unicode_string as norm_ustr,
validate_normalized_unicode_string as valid_norm,
)
if TYPE_CHECKING:
from ..network import MySQLSocket
try:
import gssapi
except ImportError:
raise ProgrammingError(
"Module gssapi is required for GSSAPI authentication "
"mechanism but was not found. Unable to authenticate "
"with the server"
) from None
from . import MySQLAuthPlugin
AUTHENTICATION_PLUGIN_CLASS = "MySQLLdapSaslPasswordAuthPlugin"
# pylint: disable=c-extension-no-member,no-member
class MySQLLdapSaslPasswordAuthPlugin(MySQLAuthPlugin):
"""Class implementing the MySQL ldap sasl authentication plugin.
The MySQL's ldap sasl authentication plugin support two authentication
methods SCRAM-SHA-1 and GSSAPI (using Kerberos). This implementation only
support SCRAM-SHA-1 and SCRAM-SHA-256.
SCRAM-SHA-1 amd SCRAM-SHA-256
This method requires 2 messages from client and 2 responses from
server.
The first message from client will be generated by prepare_password(),
after receive the response from the server, it is required that this
response is passed back to auth_continue() which will return the
second message from the client. After send this second message to the
server, the second server respond needs to be passed to auth_finalize()
to finish the authentication process.
"""
sasl_mechanisms: List[str] = ["SCRAM-SHA-1", "SCRAM-SHA-256", "GSSAPI"]
def_digest_mode: Callable = sha1
client_nonce: Optional[str] = None
client_salt: Any = None
server_salt: Optional[str] = None
krb_service_principal: Optional[str] = None
iterations: int = 0
server_auth_var: Optional[str] = None
target_name: Optional[gssapi.Name] = None
ctx: gssapi.SecurityContext = None
servers_first: Optional[str] = None
server_nonce: Optional[str] = None
@staticmethod
def _xor(bytes1: bytes, bytes2: bytes) -> bytes:
return bytes([b1 ^ b2 for b1, b2 in zip(bytes1, bytes2)])
def _hmac(self, password: bytes, salt: bytes) -> bytes:
digest_maker = hmac.new(password, salt, self.def_digest_mode)
return digest_maker.digest()
def _hi(self, password: str, salt: bytes, count: int) -> bytes:
"""Prepares Hi
Hi(password, salt, iterations) where Hi(p,s,i) is defined as
PBKDF2 (HMAC, p, s, i, output length of H).
"""
pw = password.encode()
hi = self._hmac(pw, salt + b"\x00\x00\x00\x01")
aux = hi
for _ in range(count - 1):
aux = self._hmac(pw, aux)
hi = self._xor(hi, aux)
return hi
@staticmethod
def _normalize(string: str) -> str:
norm_str = norm_ustr(string)
broken_rule = valid_norm(norm_str)
if broken_rule is not None:
raise InterfaceError(f"broken_rule: {broken_rule}")
return norm_str
def _first_message(self) -> bytes:
"""This method generates the first message to the server to start the
The client-first message consists of a gs2-header,
the desired username, and a randomly generated client nonce cnonce.
The first message from the server has the form:
b'n,a=<user_name>,n=<user_name>,r=<client_nonce>
Returns client's first message
"""
cfm_fprnat = "n,a={user_name},n={user_name},r={client_nonce}"
self.client_nonce = str(uuid4()).replace("-", "")
cfm: StrOrBytes = cfm_fprnat.format(
user_name=self._normalize(self._username),
client_nonce=self.client_nonce,
)
if isinstance(cfm, str):
cfm = cfm.encode("utf8")
return cfm
def _first_message_krb(self) -> Optional[bytes]:
"""Get a TGT Authentication request and initiates security context.
This method will contact the Kerberos KDC in order of obtain a TGT.
"""
user_name = gssapi.raw.names.import_name(
self._username.encode("utf8"), name_type=gssapi.NameType.user
)
# Use defaults store = {'ccache': 'FILE:/tmp/krb5cc_1000'}#,
# 'keytab':'/etc/some.keytab' }
# Attempt to retrieve credential from default cache file.
try:
cred: Any = gssapi.Credentials()
logger.debug(
"# Stored credentials found, if password was given it will be ignored."
)
try:
# validate credentials has not expired.
cred.lifetime
except gssapi.raw.exceptions.ExpiredCredentialsError as err:
logger.warning(" Credentials has expired: %s", err)
cred.acquire(user_name)
raise InterfaceError(f"Credentials has expired: {err}") from err
except gssapi.raw.misc.GSSError as err:
if not self._password:
raise InterfaceError(
f"Unable to retrieve stored credentials error: {err}"
) from err
try:
logger.debug("# Attempt to retrieve credentials with given password")
acquire_cred_result = gssapi.raw.acquire_cred_with_password(
user_name,
self._password.encode("utf8"),
usage="initiate",
)
cred = acquire_cred_result[0]
except gssapi.raw.misc.GSSError as err2:
raise ProgrammingError(
f"Unable to retrieve credentials with the given password: {err2}"
) from err
flags_l = (
gssapi.RequirementFlag.mutual_authentication,
gssapi.RequirementFlag.extended_error,
gssapi.RequirementFlag.delegate_to_peer,
)
if self.krb_service_principal:
service_principal = self.krb_service_principal
else:
service_principal = "ldap/ldapauth"
logger.debug("# service principal: %s", service_principal)
servk = gssapi.Name(
service_principal, name_type=gssapi.NameType.kerberos_principal
)
self.target_name = servk
self.ctx = gssapi.SecurityContext(
name=servk, creds=cred, flags=sum(flags_l), usage="initiate"
)
try:
# step() returns bytes | None, see documentation,
# so this method could return a NULL payload.
# ref: https://pythongssapi.github.io/<suffix>
# suffix: python-gssapi/latest/gssapi.html#gssapi.sec_contexts.SecurityContext
initial_client_token = self.ctx.step()
except gssapi.raw.misc.GSSError as err:
raise InterfaceError(f"Unable to initiate security context: {err}") from err
logger.debug("# initial client token: %s", initial_client_token)
return initial_client_token
def auth_continue_krb(
self, tgt_auth_challenge: Optional[bytes]
) -> Tuple[Optional[bytes], bool]:
"""Continue with the Kerberos TGT service request.
With the TGT authentication service given response generate a TGT
service request. This method must be invoked sequentially (in a loop)
until the security context is completed and an empty response needs to
be send to acknowledge the server.
Args:
tgt_auth_challenge the challenge for the negotiation.
Returns: tuple (bytearray TGS service request,
bool True if context is completed otherwise False).
"""
logger.debug("tgt_auth challenge: %s", tgt_auth_challenge)
resp = self.ctx.step(tgt_auth_challenge)
logger.debug("# context step response: %s", resp)
logger.debug("# context completed?: %s", self.ctx.complete)
return resp, self.ctx.complete
def auth_accept_close_handshake(self, message: bytes) -> bytes:
"""Accept handshake and generate closing handshake message for server.
This method verifies the server authenticity from the given message
and included signature and generates the closing handshake for the
server.
When this method is invoked the security context is already established
and the client and server can send GSSAPI formated secure messages.
To finish the authentication handshake the server sends a message
with the security layer availability and the maximum buffer size.
Since the connector only uses the GSSAPI authentication mechanism to
authenticate the user with the server, the server will verify clients
message signature and terminate the GSSAPI authentication and send two
messages; an authentication acceptance b'\x01\x00\x00\x08\x01' and a
OK packet (that must be received after sent the returned message from
this method).
Args:
message a wrapped hssapi message from the server.
Returns: bytearray closing handshake message to be send to the server.
"""
if not self.ctx.complete:
raise ProgrammingError("Security context is not completed.")
logger.debug("# servers message: %s", message)
logger.debug("# GSSAPI flags in use: %s", self.ctx.actual_flags)
try:
unwraped = self.ctx.unwrap(message)
logger.debug("# unwraped: %s", unwraped)
except gssapi.raw.exceptions.BadMICError as err:
raise InterfaceError(f"Unable to unwrap server message: {err}") from err
logger.debug("# unwrapped server message: %s", unwraped)
# The message contents for the clients closing message:
# - security level 1 byte, must be always 1.
# - conciliated buffer size 3 bytes, without importance as no
# further GSSAPI messages will be sends.
response = bytearray(b"\x01\x00\x00\00")
# Closing handshake must not be encrypted.
logger.debug("# message response: %s", response)
wraped = self.ctx.wrap(response, encrypt=False)
logger.debug(
"# wrapped message response: %s, length: %d",
wraped[0],
len(wraped[0]),
)
return wraped.message
def auth_response(
self,
auth_data: bytes,
**kwargs: Any,
) -> Optional[bytes]:
"""This method will prepare the fist message to the server.
Returns bytes to send to the server as the first message.
"""
# pylint: disable=attribute-defined-outside-init
self._auth_data = auth_data
auth_mechanism = self._auth_data.decode()
logger.debug("read_method_name_from_server: %s", auth_mechanism)
if auth_mechanism not in self.sasl_mechanisms:
auth_mechanisms = '", "'.join(self.sasl_mechanisms[:-1])
raise InterfaceError(
f'The sasl authentication method "{auth_mechanism}" requested '
f'from the server is not supported. Only "{auth_mechanisms}" '
f'and "{self.sasl_mechanisms[-1]}" are supported'
)
if b"GSSAPI" in self._auth_data:
return self._first_message_krb()
if self._auth_data == b"SCRAM-SHA-256":
self.def_digest_mode = sha256
return self._first_message()
def _second_message(self) -> bytes:
"""This method generates the second message to the server
Second message consist on the concatenation of the client and the
server nonce, and cproof.
c=<n,a=<user_name>>,r=<server_nonce>,p=<client_proof>
where:
<client_proof>: xor(<client_key>, <client_signature>)
<client_key>: hmac(salted_password, b"Client Key")
<client_signature>: hmac(<stored_key>, <auth_msg>)
<stored_key>: h(<client_key>)
<auth_msg>: <client_first_no_header>,<servers_first>,
c=<client_header>,r=<server_nonce>
<client_first_no_header>: n=<username>r=<client_nonce>
"""
if not self._auth_data:
raise InterfaceError("Missing authentication data (seed)")
passw = self._normalize(self._password)
salted_password = self._hi(passw, b64decode(self.server_salt), self.iterations)
logger.debug("salted_password: %s", b64encode(salted_password).decode())
client_key = self._hmac(salted_password, b"Client Key")
logger.debug("client_key: %s", b64encode(client_key).decode())
stored_key = self.def_digest_mode(client_key).digest()
logger.debug("stored_key: %s", b64encode(stored_key).decode())
server_key = self._hmac(salted_password, b"Server Key")
logger.debug("server_key: %s", b64encode(server_key).decode())
client_first_no_header = ",".join(
[
f"n={self._normalize(self._username)}",
f"r={self.client_nonce}",
]
)
logger.debug("client_first_no_header: %s", client_first_no_header)
client_header = b64encode(
f"n,a={self._normalize(self._username)},".encode()
).decode()
auth_msg = ",".join(
[
client_first_no_header,
self.servers_first,
f"c={client_header}",
f"r={self.server_nonce}",
]
)
logger.debug("auth_msg: %s", auth_msg)
client_signature = self._hmac(stored_key, auth_msg.encode())
logger.debug("client_signature: %s", b64encode(client_signature).decode())
client_proof = self._xor(client_key, client_signature)
logger.debug("client_proof: %s", b64encode(client_proof).decode())
self.server_auth_var = b64encode(
self._hmac(server_key, auth_msg.encode())
).decode()
logger.debug("server_auth_var: %s", self.server_auth_var)
msg = ",".join(
[
f"c={client_header}",
f"r={self.server_nonce}",
f"p={b64encode(client_proof).decode()}",
]
)
logger.debug("second_message: %s", msg)
return msg.encode()
def _validate_first_reponse(self, servers_first: bytes) -> None:
"""Validates first message from the server.
Extracts the server's salt and iterations from the servers 1st response.
First message from the server is in the form:
<server_salt>,i=<iterations>
"""
if not servers_first or not isinstance(servers_first, (bytearray, bytes)):
raise InterfaceError(f"Unexpected server message: {repr(servers_first)}")
try:
servers_first_str = servers_first.decode()
self.servers_first = servers_first_str
r_server_nonce, s_salt, i_counter = servers_first_str.split(",")
except ValueError:
raise InterfaceError(
f"Unexpected server message: {servers_first_str}"
) from None
if (
not r_server_nonce.startswith("r=")
or not s_salt.startswith("s=")
or not i_counter.startswith("i=")
):
raise InterfaceError(
f"Incomplete reponse from the server: {servers_first_str}"
)
if self.client_nonce in r_server_nonce:
self.server_nonce = r_server_nonce[2:]
logger.debug("server_nonce: %s", self.server_nonce)
else:
raise InterfaceError(
"Unable to authenticate response: response not well formed "
f"{servers_first_str}"
)
self.server_salt = s_salt[2:]
logger.debug(
"server_salt: %s length: %s",
self.server_salt,
len(self.server_salt),
)
try:
i_counter = i_counter[2:]
logger.debug("iterations: %s", i_counter)
self.iterations = int(i_counter)
except Exception as err:
raise InterfaceError(
f"Unable to authenticate: iterations not found {servers_first_str}"
) from err
def auth_continue(self, servers_first_response: bytes) -> bytes:
"""return the second message from the client.
Returns bytes to send to the server as the second message.
"""
self._validate_first_reponse(servers_first_response)
return self._second_message()
def _validate_second_reponse(self, servers_second: bytearray) -> bool:
"""Validates second message from the server.
The client and the server prove to each other they have the same Auth
variable.
The second message from the server consist of the server's proof:
server_proof = HMAC(<server_key>, <auth_msg>)
where:
<server_key>: hmac(<salted_password>, b"Server Key")
<auth_msg>: <client_first_no_header>,<servers_first>,
c=<client_header>,r=<server_nonce>
Our server_proof must be equal to the Auth variable send on this second
response.
"""
if (
not servers_second
or not isinstance(servers_second, bytearray)
or len(servers_second) <= 2
or not servers_second.startswith(b"v=")
):
raise InterfaceError("The server's proof is not well formated")
server_var = servers_second[2:].decode()
logger.debug("server auth variable: %s", server_var)
return self.server_auth_var == server_var
def auth_finalize(self, servers_second_response: bytearray) -> bool:
"""finalize the authentication process.
Raises InterfaceError if the ervers_second_response is invalid.
Returns True in successful authentication False otherwise.
"""
if not self._validate_second_reponse(servers_second_response):
raise InterfaceError(
"Authentication failed: Unable to proof server identity"
)
return True
@property
def name(self) -> str:
"""Plugin official name."""
return "authentication_ldap_sasl_client"
@property
def requires_ssl(self) -> bool:
"""Signals whether or not SSL is required."""
return False
async def auth_switch_response(
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
) -> bytes:
"""Handles server's `auth switch request` response.
Args:
sock: Pointer to the socket connection.
auth_data: Plugin provided data (extracted from a packet
representing an `auth switch request` response).
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Last server's response after back-and-forth
communication.
"""
logger.debug("# auth_data: %s", auth_data)
self.krb_service_principal = kwargs.get("krb_service_principal")
response = self.auth_response(auth_data, **kwargs)
if response is None:
raise InterfaceError("Got a NULL auth response")
logger.debug("# request: %s size: %s", response, len(response))
await sock.write(response)
packet = await sock.read()
logger.debug("# server response packet: %s", packet)
if len(packet) >= 6 and packet[5] == 114 and packet[6] == 61: # 'r' and '='
# Continue with sasl authentication
dec_response = packet[5:]
cresponse = self.auth_continue(dec_response)
await sock.write(cresponse)
packet = await sock.read()
if packet[5] == 118 and packet[6] == 61: # 'v' and '='
if self.auth_finalize(packet[5:]):
# receive packed OK
packet = await sock.read()
elif auth_data == b"GSSAPI" and packet[4] != ERR_STATUS:
rcode_size = 5 # header size for the response status code.
logger.debug("# Continue with sasl GSSAPI authentication")
logger.debug("# response header: %s", packet[: rcode_size + 1])
logger.debug("# response size: %s", len(packet))
logger.debug("# Negotiate a service request")
complete = False
tries = 0 # To avoid a infinite loop attempt no more than feedback messages
while not complete and tries < 5:
logger.debug("%s Attempt %s %s", "-" * 20, tries + 1, "-" * 20)
logger.debug("<< server response: %s", packet)
logger.debug("# response code: %s", packet[: rcode_size + 1])
step, complete = self.auth_continue_krb(packet[rcode_size:])
logger.debug(" >> response to server: %s", step)
await sock.write(step or b"")
packet = await sock.read()
tries += 1
if not complete:
raise InterfaceError(
f"Unable to fulfill server request after {tries} "
f"attempts. Last server response: {packet}"
)
logger.debug(
" last GSSAPI response from server: %s length: %d",
packet,
len(packet),
)
last_step = self.auth_accept_close_handshake(packet[rcode_size:])
logger.debug(
" >> last response to server: %s length: %d",
last_step,
len(last_step),
)
await sock.write(last_step)
# Receive final handshake from server
packet = await sock.read()
logger.debug("<< final handshake from server: %s", packet)
# receive OK packet from server.
packet = await sock.read()
logger.debug("<< ok packet from server: %s", packet)
return bytes(packet)
# pylint: enable=c-extension-no-member,no-member

View file

@ -0,0 +1,234 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
# mypy: disable-error-code="arg-type,union-attr,call-arg"
"""OCI Authentication Plugin."""
import json
import os
from base64 import b64encode
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional
from mysql.connector import errors
from mysql.connector.logger import logger
if TYPE_CHECKING:
from ..network import MySQLSocket
try:
from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.types import PRIVATE_KEY_TYPES
except ImportError:
raise errors.ProgrammingError("Package 'cryptography' is not installed") from None
try:
from oci import config, exceptions
except ImportError:
raise errors.ProgrammingError(
"Package 'oci' (Oracle Cloud Infrastructure Python SDK) is not installed"
) from None
from . import MySQLAuthPlugin
AUTHENTICATION_PLUGIN_CLASS = "MySQLOCIAuthPlugin"
OCI_SECURITY_TOKEN_MAX_SIZE = 10 * 1024 # In bytes
OCI_SECURITY_TOKEN_TOO_LARGE = "Ephemeral security token is too large (10KB max)"
OCI_SECURITY_TOKEN_FILE_NOT_AVAILABLE = (
"Ephemeral security token file ('security_token_file') could not be read"
)
OCI_PROFILE_MISSING_PROPERTIES = (
"OCI configuration file does not contain a 'fingerprint' or 'key_file' entry"
)
class MySQLOCIAuthPlugin(MySQLAuthPlugin):
"""Implement the MySQL OCI IAM authentication plugin."""
context: Any = None
oci_config_profile: str = "DEFAULT"
oci_config_file: str = config.DEFAULT_LOCATION
@staticmethod
def _prepare_auth_response(signature: bytes, oci_config: Dict[str, Any]) -> str:
"""Prepare client's authentication response
Prepares client's authentication response in JSON format
Args:
signature (bytes): server's nonce to be signed by client.
oci_config (dict): OCI configuration object.
Returns:
str: JSON string with the following format:
{"fingerprint": str, "signature": str, "token": base64.base64.base64}
Raises:
ProgrammingError: If the ephemeral security token file can't be open or the
token is too large.
"""
signature_64 = b64encode(signature)
auth_response = {
"fingerprint": oci_config["fingerprint"],
"signature": signature_64.decode(),
}
# The security token, if it exists, should be a JWT (JSON Web Token), consisted
# of a base64-encoded header, body, and signature, separated by '.',
# e.g. "Base64.Base64.Base64", stored in a file at the path specified by the
# security_token_file configuration property
if oci_config.get("security_token_file"):
try:
security_token_file = Path(oci_config["security_token_file"])
# Check if token exceeds the maximum size
if security_token_file.stat().st_size > OCI_SECURITY_TOKEN_MAX_SIZE:
raise errors.ProgrammingError(OCI_SECURITY_TOKEN_TOO_LARGE)
auth_response["token"] = security_token_file.read_text(encoding="utf-8")
except (OSError, UnicodeError) as err:
raise errors.ProgrammingError(
OCI_SECURITY_TOKEN_FILE_NOT_AVAILABLE
) from err
return json.dumps(auth_response, separators=(",", ":"))
@staticmethod
def _get_private_key(key_path: str) -> PRIVATE_KEY_TYPES:
"""Get the private_key form the given location"""
try:
with open(os.path.expanduser(key_path), "rb") as key_file:
private_key = serialization.load_pem_private_key(
key_file.read(),
password=None,
)
except (TypeError, OSError, ValueError, UnsupportedAlgorithm) as err:
raise errors.ProgrammingError(
"An error occurred while reading the API_KEY from "
f'"{key_path}": {err}'
)
return private_key
def _get_valid_oci_config(self) -> Dict[str, Any]:
"""Get a valid OCI config from the given configuration file path"""
error_list = []
req_keys = {
"fingerprint": (lambda x: len(x) > 32),
"key_file": (lambda x: os.path.exists(os.path.expanduser(x))),
}
oci_config: Dict[str, Any] = {}
try:
# key_file is validated by oci.config if present
oci_config = config.from_file(
self.oci_config_file or config.DEFAULT_LOCATION,
self.oci_config_profile or "DEFAULT",
)
for req_key, req_value in req_keys.items():
try:
# Verify parameter in req_key is present and valid
if oci_config[req_key] and not req_value(oci_config[req_key]):
error_list.append(f'Parameter "{req_key}" is invalid')
except KeyError:
error_list.append(f"Does not contain parameter {req_key}")
except (
exceptions.ConfigFileNotFound,
exceptions.InvalidConfig,
exceptions.InvalidKeyFilePath,
exceptions.InvalidPrivateKey,
exceptions.ProfileNotFound,
) as err:
error_list.append(str(err))
# Raise errors if any
if error_list:
raise errors.ProgrammingError(
f"Invalid oci-config-file: {self.oci_config_file}. "
f"Errors found: {error_list}"
)
return oci_config
@property
def name(self) -> str:
"""Plugin official name."""
return "authentication_oci_client"
@property
def requires_ssl(self) -> bool:
"""Signals whether or not SSL is required."""
return False
def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
"""Prepare authentication string for the server."""
logger.debug("server nonce: %s, len %d", auth_data, len(auth_data))
oci_config = self._get_valid_oci_config()
private_key = self._get_private_key(oci_config["key_file"])
signature = private_key.sign(auth_data, padding.PKCS1v15(), hashes.SHA256())
auth_response = self._prepare_auth_response(signature, oci_config)
logger.debug("authentication response: %s", auth_response)
return auth_response.encode()
async def auth_switch_response(
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
) -> bytes:
"""Handles server's `auth switch request` response.
Args:
sock: Pointer to the socket connection.
auth_data: Plugin provided data (extracted from a packet
representing an `auth switch request` response).
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Last server's response after back-and-forth
communication.
"""
self.oci_config_file = kwargs.get("oci_config_file", "DEFAULT")
self.oci_config_profile = kwargs.get(
"oci_config_profile", config.DEFAULT_LOCATION
)
logger.debug("# oci configuration file path: %s", self.oci_config_file)
response = self.auth_response(auth_data, **kwargs)
if response is None:
raise errors.InterfaceError("Got a NULL auth response")
logger.debug("# request: %s size: %s", response, len(response))
await sock.write(response)
packet = await sock.read()
logger.debug("# server response packet: %s", packet)
return bytes(packet)

View file

@ -0,0 +1,172 @@
# Copyright (c) 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""OpenID Authentication Plugin."""
import re
from pathlib import Path
from typing import Any, List
from mysql.connector import errors, utils
from mysql.connector.aio.network import MySQLSocket
from mysql.connector.logger import logger
from . import MySQLAuthPlugin
AUTHENTICATION_PLUGIN_CLASS = "MySQLOpenIDConnectAuthPlugin"
OPENID_TOKEN_MAX_SIZE = 10 * 1024 # In bytes
class MySQLOpenIDConnectAuthPlugin(MySQLAuthPlugin):
"""Class implementing the MySQL OpenID Connect Authentication Plugin."""
_openid_capability_flag: bytes = utils.int1store(1)
@property
def name(self) -> str:
"""Plugin official name."""
return "authentication_openid_connect_client"
@property
def requires_ssl(self) -> bool:
"""Signals whether or not SSL is required."""
return True
@staticmethod
def _validate_openid_token(token: str) -> bool:
"""Helper method used to validate OpenID Connect token
The Token is represented as a JSON Web Token (JWT) consists of a
base64-encoded header, body, and signature, separated by '.' e.g.,
"Base64url.Base64url.Base64url". The First part of the token contains
the header, the second part contains payload and the third part contains
signature. These token parts should be Base64 URLSafe i.e., Token cannot
contain characters other than a-z, A-Z, 0-9 and special characters '-', '_'.
Args:
token (str): Base64url-encoded OpenID connect token fetched from
the file path passed via `openid_token_file` connection
argument.
Returns:
bool: Signal indicating whether the token is valid or not.
"""
header_payload_sig: List[str] = token.split(".")
if len(header_payload_sig) != 3:
# invalid structure
return False
urlsafe_pattern = re.compile("^[a-zA-Z0-9-_]*$")
return all(
(
len(token_part) and urlsafe_pattern.search(token_part) is not None
for token_part in header_payload_sig
)
)
def auth_response(self, auth_data: bytes, **kwargs: Any) -> bytes:
"""Prepares authentication string for the server.
Args:
auth_data: Authorization data.
kwargs: Custom configuration to be passed to the auth plugin
when invoked.
Returns:
packet: Client's authorization response.
The OpenID Connect authorization response follows the pattern :-
int<1> capability flag
string<lenenc> id token
Raises:
InterfaceError: If the connection is insecure or the OpenID Token is too large,
invalid or non-existent.
ProgrammingError: If the OpenID Token file could not be read.
"""
try:
# Check if the connection is secure
if self.requires_ssl and not self._ssl_enabled:
raise errors.InterfaceError(f"{self.name} requires SSL")
# Validate the file
token_file_path: str = kwargs.get("openid_token_file", None)
openid_token_file: Path = Path(token_file_path)
# Check if token exceeds the maximum size
if openid_token_file.stat().st_size > OPENID_TOKEN_MAX_SIZE:
raise errors.InterfaceError(
"The OpenID Connect token file size is too large (> 10KB)"
)
openid_token: str = openid_token_file.read_text(encoding="utf-8")
openid_token = openid_token.strip()
# Validate the JWT Token
if not self._validate_openid_token(openid_token):
raise errors.InterfaceError("The OpenID Connect Token is invalid")
# build the auth_response packet
auth_response: List[bytes] = [
self._openid_capability_flag,
utils.lc_int(len(openid_token)),
openid_token.encode(),
]
return b"".join(auth_response)
except (SyntaxError, TypeError, OSError, UnicodeError) as err:
raise errors.ProgrammingError(
"The OpenID Connect Token File (openid_token_file) could not be read"
) from err
async def auth_switch_response(
self, sock: MySQLSocket, auth_data: bytes, **kwargs: Any
) -> bytes:
"""Handles server's `auth switch request` response.
Args:
sock: Pointer to the socket connection.
auth_data: Plugin provided data (extracted from a packet
representing an `auth switch request` response).
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Last server's response after back-and-forth
communication.
Raises:
InterfaceError: If a NULL auth response is received from auth_response method.
"""
response = self.auth_response(auth_data, **kwargs)
if response is None:
raise errors.InterfaceError("Got a NULL auth response")
logger.debug("# request: %s size: %s", response, len(response))
await sock.write(response)
packet = await sock.read()
logger.debug("# server response packet: %s", packet)
return bytes(packet)

View file

@ -0,0 +1,291 @@
# Copyright (c) 2023, 2025, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""WebAuthn Authentication Plugin."""
from typing import TYPE_CHECKING, Any, Callable, Optional
from mysql.connector import errors, utils
from ..logger import logger
from . import MySQLAuthPlugin
if TYPE_CHECKING:
from ..network import MySQLSocket
try:
from fido2.cbor import dump_bytes as cbor_dump_bytes
from fido2.client import Fido2Client, UserInteraction
from fido2.hid import CtapHidDevice
from fido2.webauthn import PublicKeyCredentialRequestOptions
except ImportError as import_err:
raise errors.ProgrammingError(
"Module fido2 is required for WebAuthn authentication mechanism but was "
"not found. Unable to authenticate with the server"
) from import_err
try:
from fido2.pcsc import CtapPcscDevice
CTAP_PCSC_DEVICE_AVAILABLE = True
except ModuleNotFoundError:
CTAP_PCSC_DEVICE_AVAILABLE = False
AUTHENTICATION_PLUGIN_CLASS = "MySQLWebAuthnAuthPlugin"
class ClientInteraction(UserInteraction):
"""Provides user interaction to the Client."""
def __init__(self, callback: Optional[Callable] = None):
self.callback = callback
self.msg = (
"Please insert FIDO device and perform gesture action for authentication "
"to complete."
)
def prompt_up(self) -> None:
"""Prompt message for the user interaction with the FIDO device."""
if self.callback is None:
print(self.msg)
else:
self.callback(self.msg)
class MySQLWebAuthnAuthPlugin(MySQLAuthPlugin):
"""Class implementing the MySQL WebAuthn authentication plugin."""
client: Optional[Fido2Client] = None
callback: Optional[Callable] = None
options: dict = {"rpId": None, "challenge": None, "allowCredentials": []}
@property
def name(self) -> str:
"""Plugin official name."""
return "authentication_webauthn_client"
@property
def requires_ssl(self) -> bool:
"""Signals whether or not SSL is required."""
return False
def get_assertion_response(
self, credential_id: Optional[bytearray] = None
) -> bytes:
"""Get assertion from authenticator and return the response.
Args:
credential_id (Optional[bytearray]): The credential ID.
Returns:
bytearray: The response packet with the data from the assertion.
"""
if self.client is None:
raise errors.InterfaceError("No WebAuthn client found")
if credential_id is not None:
# If credential_id is not None, it's because the FIDO device does not
# support resident keys and the credential_id was requested from the server
self.options["allowCredentials"] = [
{
"id": credential_id,
"type": "public-key",
}
]
# Get assertion from authenticator
assertion = self.client.get_assertion(
PublicKeyCredentialRequestOptions.from_dict(self.options)
)
number_of_assertions = len(assertion.get_assertions())
client_data_json = b""
# Build response packet
#
# Format:
# int<1> 0x02 (2) status tag
# int<lenenc> number of assertions length encoded number of assertions
# string authenticator data variable length raw binary string
# string signed challenge variable length raw binary string
# ...
# ...
# string authenticator data variable length raw binary string
# string signed challenge variable length raw binary string
# string ClientDataJSON variable length raw binary string
packet = utils.lc_int(2)
packet += utils.lc_int(number_of_assertions)
# Add authenticator data and signed challenge for each assertion
for i in range(number_of_assertions):
assertion_response = assertion.get_response(i)
# string<lenenc> authenticator_data
authenticator_data = cbor_dump_bytes(assertion_response.authenticator_data)
# string<lenenc> signed_challenge
signature = assertion_response.signature
packet += utils.lc_int(len(authenticator_data))
packet += authenticator_data
packet += utils.lc_int(len(signature))
packet += signature
# string<lenenc> client_data_json
client_data_json = assertion_response.client_data
packet += utils.lc_int(len(client_data_json))
packet += client_data_json
logger.debug("WebAuthn - payload response packet: %s", packet)
return packet
def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
"""Find authenticator device and check if supports resident keys.
It also creates a Fido2Client using the relying party ID from the server.
Raises:
InterfaceError: When the FIDO device is not found.
Returns:
bytes: 2 if the authenticator supports resident keys else 1.
"""
try:
packets, capability = utils.read_int(auth_data, 1)
challenge, rp_id = utils.read_lc_string_list(packets)
self.options["challenge"] = challenge
self.options["rpId"] = rp_id.decode()
logger.debug("WebAuthn - capability: %d", capability)
logger.debug("WebAuthn - challenge: %s", self.options["challenge"])
logger.debug("WebAuthn - relying party id: %s", self.options["rpId"])
except ValueError as err:
raise errors.InterfaceError(
"Unable to parse MySQL WebAuthn authentication data"
) from err
# Locate a device
device = next(CtapHidDevice.list_devices(), None)
if device is not None:
logger.debug("WebAuthn - Use USB HID channel")
elif CTAP_PCSC_DEVICE_AVAILABLE:
device = next(CtapPcscDevice.list_devices(), None) # type: ignore[arg-type]
if device is None:
raise errors.InterfaceError("No FIDO device found")
# Set up a FIDO 2 client using the origin relying party id
self.client = Fido2Client(
device,
f"https://{self.options['rpId']}",
user_interaction=ClientInteraction(self.callback),
)
if not self.client.info.options.get("rk"):
logger.debug("WebAuthn - Authenticator doesn't support resident keys")
return b"1"
logger.debug("WebAuthn - Authenticator with support for resident key found")
return b"2"
async def auth_more_response(
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
) -> bytes:
"""Handles server's `auth more data` response.
Args:
sock: Pointer to the socket connection.
auth_data: Authentication method data (from a packet representing
an `auth more data` response).
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Last server's response after back-and-forth
communication.
"""
_, credential_id = utils.read_lc_string(auth_data)
response = self.get_assertion_response(credential_id)
logger.debug("WebAuthn - request: %s size: %s", response, len(response))
await sock.write(response)
pkt = bytes(await sock.read())
logger.debug("WebAuthn - server response packet: %s", pkt)
return pkt
async def auth_switch_response(
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
) -> bytes:
"""Handles server's `auth switch request` response.
Args:
sock: Pointer to the socket connection.
auth_data: Plugin provided data (extracted from a packet
representing an `auth switch request` response).
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Last server's response after back-and-forth
communication.
"""
webauth_callback = kwargs.get("webauthn_callback") or kwargs.get(
"fido_callback"
)
self.callback = (
utils.import_object(webauth_callback)
if isinstance(webauth_callback, str)
else webauth_callback
)
response = self.auth_response(auth_data)
credential_id = None
if response == b"1":
# Authenticator doesn't support resident keys, request credential_id
logger.debug("WebAuthn - request credential_id")
await sock.write(utils.lc_int(int(response)))
# return a packet representing an `auth more data` response
return bytes(await sock.read())
response = self.get_assertion_response(credential_id)
logger.debug("WebAuthn - request: %s size: %s", response, len(response))
await sock.write(response)
pkt = bytes(await sock.read())
logger.debug("WebAuthn - server response packet: %s", pkt)
return pkt

View file

@ -0,0 +1,160 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Caching SHA2 Password Authentication Plugin."""
import struct
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Optional
from mysql.connector.errors import InterfaceError
from mysql.connector.logger import logger
from . import MySQLAuthPlugin
if TYPE_CHECKING:
from ..network import MySQLSocket
AUTHENTICATION_PLUGIN_CLASS = "MySQLCachingSHA2PasswordAuthPlugin"
class MySQLCachingSHA2PasswordAuthPlugin(MySQLAuthPlugin):
"""Class implementing the MySQL caching_sha2_password authentication plugin
Note that encrypting using RSA is not supported since the Python
Standard Library does not provide this OpenSSL functionality.
"""
perform_full_authentication: int = 4
def _scramble(self, auth_data: bytes) -> bytes:
"""Return a scramble of the password using a Nonce sent by the
server.
The scramble is of the form:
XOR(SHA2(password), SHA2(SHA2(SHA2(password)), Nonce))
"""
if not auth_data:
raise InterfaceError("Missing authentication data (seed)")
if not self._password:
return b""
hash1 = sha256(self._password.encode()).digest()
hash2 = sha256()
hash2.update(sha256(hash1).digest())
hash2.update(auth_data)
hash2_digest = hash2.digest()
xored = [h1 ^ h2 for (h1, h2) in zip(hash1, hash2_digest)]
hash3 = struct.pack("32B", *xored)
return hash3
@property
def name(self) -> str:
"""Plugin official name."""
return "caching_sha2_password"
@property
def requires_ssl(self) -> bool:
"""Signals whether or not SSL is required."""
return False
def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
"""Make the client's authorization response.
Args:
auth_data: Authorization data.
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Client's authorization response.
"""
if not auth_data:
return None
if len(auth_data) > 1:
return self._scramble(auth_data)
if auth_data[0] == self.perform_full_authentication:
# return password as clear text.
return self._password.encode() + b"\x00"
return None
async def auth_more_response(
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
) -> bytes:
"""Handles server's `auth more data` response.
Args:
sock: Pointer to the socket connection.
auth_data: Authentication method data (from a packet representing
an `auth more data` response).
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Last server's response after back-and-forth
communication.
"""
response = self.auth_response(auth_data, **kwargs)
if response:
await sock.write(response)
return bytes(await sock.read())
async def auth_switch_response(
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
) -> bytes:
"""Handles server's `auth switch request` response.
Args:
sock: Pointer to the socket connection.
auth_data: Plugin provided data (extracted from a packet
representing an `auth switch request` response).
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Last server's response after back-and-forth
communication.
"""
response = self.auth_response(auth_data, **kwargs)
if response is None:
raise InterfaceError("Got a NULL auth response")
logger.debug("# request: %s size: %s", response, len(response))
await sock.write(response)
pkt = bytes(await sock.read())
logger.debug("# server response packet: %s", pkt)
return pkt

View file

@ -0,0 +1,105 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Clear Password Authentication Plugin."""
from typing import TYPE_CHECKING, Any, Optional
from mysql.connector import errors
from mysql.connector.logger import logger
from . import MySQLAuthPlugin
if TYPE_CHECKING:
from ..network import MySQLSocket
AUTHENTICATION_PLUGIN_CLASS = "MySQLClearPasswordAuthPlugin"
class MySQLClearPasswordAuthPlugin(MySQLAuthPlugin):
"""Class implementing the MySQL Clear Password authentication plugin"""
def _prepare_password(self) -> bytes:
"""Prepare and return password as as clear text.
Returns:
bytes: Prepared password.
"""
return self._password.encode() + b"\x00"
@property
def name(self) -> str:
"""Plugin official name."""
return "mysql_clear_password"
@property
def requires_ssl(self) -> bool:
"""Signals whether or not SSL is required."""
return False
def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
"""Return the prepared password to send to MySQL.
Raises:
InterfaceError: When SSL is required by not enabled.
Returns:
str: The prepared password.
"""
if self.requires_ssl and not self._ssl_enabled:
raise errors.InterfaceError(f"{self.name} requires SSL")
return self._prepare_password()
async def auth_switch_response(
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
) -> bytes:
"""Handles server's `auth switch request` response.
Args:
sock: Pointer to the socket connection.
auth_data: Plugin provided data (extracted from a packet
representing an `auth switch request` response).
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Last server's response after back-and-forth
communication.
"""
response = self.auth_response(auth_data, **kwargs)
if response is None:
raise errors.InterfaceError("Got a NULL auth response")
logger.debug("# request: %s size: %s", response, len(response))
await sock.write(response)
pkt = bytes(await sock.read())
logger.debug("# server response packet: %s", pkt)
return pkt

View file

@ -0,0 +1,121 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Native Password Authentication Plugin."""
import struct
from hashlib import sha1
from typing import TYPE_CHECKING, Any, Optional
from mysql.connector.errors import InterfaceError
from mysql.connector.logger import logger
from . import MySQLAuthPlugin
if TYPE_CHECKING:
from ..network import MySQLSocket
AUTHENTICATION_PLUGIN_CLASS = "MySQLNativePasswordAuthPlugin"
class MySQLNativePasswordAuthPlugin(MySQLAuthPlugin):
"""Class implementing the MySQL Native Password authentication plugin"""
def _prepare_password(self, auth_data: bytes) -> bytes:
"""Prepares and returns password as native MySQL 4.1+ password"""
if not auth_data:
raise InterfaceError("Missing authentication data (seed)")
if not self._password:
return b""
hash4 = None
try:
hash1 = sha1(self._password.encode()).digest()
hash2 = sha1(hash1).digest()
hash3 = sha1(auth_data + hash2).digest()
xored = [h1 ^ h3 for (h1, h3) in zip(hash1, hash3)]
hash4 = struct.pack("20B", *xored)
except (struct.error, TypeError) as err:
raise InterfaceError(f"Failed scrambling password; {err}") from err
return hash4
@property
def name(self) -> str:
"""Plugin official name."""
return "mysql_native_password"
@property
def requires_ssl(self) -> bool:
"""Signals whether or not SSL is required."""
return False
def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
"""Make the client's authorization response.
Args:
auth_data: Authorization data.
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Client's authorization response.
"""
return self._prepare_password(auth_data)
async def auth_switch_response(
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
) -> bytes:
"""Handles server's `auth switch request` response.
Args:
sock: Pointer to the socket connection.
auth_data: Plugin provided data (extracted from a packet
representing an `auth switch request` response).
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Last server's response after back-and-forth
communication.
"""
response = self.auth_response(auth_data, **kwargs)
if response is None:
raise InterfaceError("Got a NULL auth response")
logger.debug("# request: %s size: %s", response, len(response))
await sock.write(response)
pkt = bytes(await sock.read())
logger.debug("# server response packet: %s", pkt)
return pkt

View file

@ -0,0 +1,109 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""SHA256 Password Authentication Plugin."""
from typing import TYPE_CHECKING, Any, Optional
from mysql.connector import errors
from mysql.connector.logger import logger
from . import MySQLAuthPlugin
if TYPE_CHECKING:
from ..network import MySQLSocket
AUTHENTICATION_PLUGIN_CLASS = "MySQLSHA256PasswordAuthPlugin"
class MySQLSHA256PasswordAuthPlugin(MySQLAuthPlugin):
"""Class implementing the MySQL SHA256 authentication plugin
Note that encrypting using RSA is not supported since the Python
Standard Library does not provide this OpenSSL functionality.
"""
def _prepare_password(self) -> bytes:
"""Prepare and return password as as clear text.
Returns:
password (bytes): Prepared password.
"""
return self._password.encode() + b"\x00"
@property
def name(self) -> str:
"""Plugin official name."""
return "sha256_password"
@property
def requires_ssl(self) -> bool:
"""Signals whether or not SSL is required."""
return True
def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
"""Return the prepared password to send to MySQL.
Raises:
InterfaceError: When SSL is required by not enabled.
Returns:
str: The prepared password.
"""
if self.requires_ssl and not self.ssl_enabled:
raise errors.InterfaceError(f"{self.name} requires SSL")
return self._prepare_password()
async def auth_switch_response(
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
) -> bytes:
"""Handles server's `auth switch request` response.
Args:
sock: Pointer to the socket connection.
auth_data: Plugin provided data (extracted from a packet
representing an `auth switch request` response).
kwargs: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override the ones
defined in the auth plugin itself.
Returns:
packet: Last server's response after back-and-forth
communication.
"""
response = self.auth_response(auth_data, **kwargs)
if response is None:
raise errors.InterfaceError("Got a NULL auth response")
logger.debug("# request: %s size: %s", response, len(response))
await sock.write(response)
pkt = bytes(await sock.read())
logger.debug("# server response packet: %s", pkt)
return pkt

View file

@ -0,0 +1,325 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Implements the MySQL Client/Server protocol."""
__all__ = ["MySQLProtocol"]
import struct
from typing import Any, Dict, List, Optional, Tuple
from ..constants import ClientFlag, ServerCmd
from ..errors import InterfaceError, ProgrammingError, get_exception
from ..logger import logger
from ..protocol import (
DEFAULT_CHARSET_ID,
DEFAULT_MAX_ALLOWED_PACKET,
MySQLProtocol as _MySQLProtocol,
)
from ..types import BinaryProtocolType, DescriptionType, EofPacketType, HandShakeType
from ..utils import lc_int, read_lc_string_list
from .network import MySQLSocket
from .plugins import MySQLAuthPlugin, get_auth_plugin
from .plugins.caching_sha2_password import MySQLCachingSHA2PasswordAuthPlugin
class MySQLProtocol(_MySQLProtocol):
"""Implements MySQL client/server protocol.
Create and parses MySQL packets.
"""
@staticmethod
def auth_plugin_first_response( # type: ignore[override]
auth_data: bytes,
username: str,
password: str,
auth_plugin: str,
auth_plugin_class: Optional[str] = None,
ssl_enabled: bool = False,
plugin_config: Optional[Dict[str, Any]] = None,
) -> Tuple[bytes, MySQLAuthPlugin]:
"""Prepare the first authentication response.
Args:
auth_data: Authorization data from initial handshake.
username: Account's username.
password: Account's password.
client_flags: Integer representing client capabilities flags.
auth_plugin: Authorization plugin name.
auth_plugin_class: Authorization plugin class (has higher precedence
than the authorization plugin name).
ssl_enabled: Whether SSL is enabled or not.
plugin_config: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override
the ones defined in the auth plugin itself.
Returns:
auth_response: Authorization plugin response.
auth_strategy: Authorization plugin instance created based
on the provided `auth_plugin` and `auth_plugin_class`
parameters.
Raises:
InterfaceError: If authentication fails or when got a NULL auth response.
"""
if not password and auth_plugin == "":
# return auth response and an arbitrary auth strategy
return b"\x00", MySQLCachingSHA2PasswordAuthPlugin(
username, password, ssl_enabled=ssl_enabled
)
if plugin_config is None:
plugin_config = {}
try:
auth_strategy = get_auth_plugin(auth_plugin, auth_plugin_class)(
username, password, ssl_enabled=ssl_enabled
)
auth_response = auth_strategy.auth_response(auth_data, **plugin_config)
except (TypeError, InterfaceError) as err:
raise InterfaceError(f"Failed authentication: {err}") from err
if auth_response is None:
raise InterfaceError(
"Got NULL auth response while authenticating with "
f"plugin {auth_strategy.name}"
)
auth_response = lc_int(len(auth_response)) + auth_response
return auth_response, auth_strategy
@staticmethod
def make_auth( # type: ignore[override]
handshake: HandShakeType,
username: str,
password: str,
database: Optional[str] = None,
charset: int = DEFAULT_CHARSET_ID,
client_flags: int = 0,
max_allowed_packet: int = DEFAULT_MAX_ALLOWED_PACKET,
auth_plugin: Optional[str] = None,
auth_plugin_class: Optional[str] = None,
conn_attrs: Optional[Dict[str, str]] = None,
is_change_user_request: bool = False,
ssl_enabled: bool = False,
plugin_config: Optional[Dict[str, Any]] = None,
) -> Tuple[bytes, MySQLAuthPlugin]:
"""Make a MySQL Authentication packet.
Args:
handshake: Initial handshake.
username: Account's username.
password: Account's password.
database: Initial database name for the connection
charset: Client charset (see [2]), only the lower 8-bits.
client_flags: Integer representing client capabilities flags.
max_allowed_packet: Maximum packet size.
auth_plugin: Authorization plugin name.
auth_plugin_class: Authorization plugin class (has higher precedence
than the authorization plugin name).
conn_attrs: Connection attributes.
is_change_user_request: Whether is a `change user request` operation or not.
ssl_enabled: Whether SSL is enabled or not.
plugin_config: Custom configuration to be passed to the auth plugin
when invoked. The parameters defined here will override
the one defined in the auth plugin itself.
Returns:
handshake_response: Handshake response as per [1].
auth_strategy: Authorization plugin instance created based
on the provided `auth_plugin` and `auth_plugin_class`.
Raises:
ProgrammingError: Handshake misses authentication info.
References:
[1]: https://dev.mysql.com/doc/dev/mysql-server/latest/\
page_protocol_connection_phase_packets_protocol_handshake_response.html
[2]: https://dev.mysql.com/doc/dev/mysql-server/latest/\
page_protocol_basic_character_set.html#a_protocol_character_set
"""
b_username = username.encode()
response_payload = []
if is_change_user_request:
logger.debug("Got a `change user` request")
logger.debug("Starting authorization phase")
if handshake is None:
raise ProgrammingError("Got a NULL handshake") from None
if handshake.get("auth_data") is None:
raise ProgrammingError("Handshake misses authentication info") from None
try:
auth_plugin = auth_plugin or handshake["auth_plugin"] # type: ignore[assignment]
except (TypeError, KeyError) as err:
raise ProgrammingError(
f"Handshake misses authentication plugin info ({err})"
) from None
logger.debug("The provided initial strategy is %s", auth_plugin)
if is_change_user_request:
response_payload.append(
struct.pack(
f"<B{len(b_username)}sx",
ServerCmd.CHANGE_USER,
b_username,
)
)
else:
filler = "x" * 23
response_payload.append(
struct.pack(
f"<IIB{filler}{len(b_username)}sx",
client_flags,
max_allowed_packet,
charset,
b_username,
)
)
# auth plugin response
auth_response, auth_strategy = MySQLProtocol.auth_plugin_first_response(
auth_data=handshake["auth_data"], # type: ignore[arg-type]
username=username,
password=password,
auth_plugin=auth_plugin,
auth_plugin_class=auth_plugin_class,
ssl_enabled=ssl_enabled,
plugin_config=plugin_config,
)
response_payload.append(auth_response)
# database name
response_payload.append(MySQLProtocol.connect_with_db(client_flags, database))
# charset
if is_change_user_request:
response_payload.append(struct.pack("<H", charset))
# plugin name
if client_flags & ClientFlag.PLUGIN_AUTH:
response_payload.append(auth_plugin.encode() + b"\x00")
# connection attributes
if (client_flags & ClientFlag.CONNECT_ARGS) and conn_attrs is not None:
response_payload.append(MySQLProtocol.make_conn_attrs(conn_attrs))
return b"".join(response_payload), auth_strategy
# pylint: disable=invalid-overridden-method
async def read_binary_result( # type: ignore[override]
self,
sock: MySQLSocket,
columns: List[DescriptionType],
count: int = 1,
charset: str = "utf-8",
read_timeout: Optional[int] = None,
) -> Tuple[
List[Tuple[BinaryProtocolType, ...]],
Optional[EofPacketType],
]:
"""Read MySQL binary protocol result.
Reads all or given number of binary resultset rows from the socket.
"""
rows = []
eof = None
values = None
i = 0
while True:
if eof or i == count:
break
packet = await sock.read(read_timeout)
if packet[4] == 254:
eof = self.parse_eof(packet)
values = None
elif packet[4] == 0:
eof = None
values = self._parse_binary_values(columns, packet[5:], charset)
if eof is None and values is not None:
rows.append(values)
elif eof is None and values is None:
raise get_exception(packet)
i += 1
return (rows, eof)
# pylint: disable=invalid-overridden-method
async def read_text_result( # type: ignore[override]
self,
sock: MySQLSocket,
version: Tuple[int, ...],
count: int = 1,
read_timeout: Optional[int] = None,
) -> Tuple[
List[Tuple[Optional[bytes], ...]],
Optional[EofPacketType],
]:
"""Read MySQL text result.
Reads all or given number of rows from the socket.
Returns a tuple with 2 elements: a list with all rows and
the EOF packet.
"""
# Keep unused 'version' for API backward compatibility
_ = version
rows = []
eof = None
rowdata = None
i = 0
while True:
if eof or i == count:
break
packet = await sock.read(read_timeout)
if packet.startswith(b"\xff\xff\xff"):
datas = [packet[4:]]
packet = await sock.read(read_timeout)
while packet.startswith(b"\xff\xff\xff"):
datas.append(packet[4:])
packet = await sock.read(read_timeout)
datas.append(packet[4:])
rowdata = read_lc_string_list(b"".join(datas))
elif packet[4] == 254 and packet[0] < 7:
eof = self.parse_eof(packet)
rowdata = None
else:
eof = None
rowdata = read_lc_string_list(bytes(packet[4:]))
if eof is None and rowdata is not None:
rows.append(rowdata)
elif eof is None and rowdata is None:
raise get_exception(packet)
i += 1
return rows, eof

View file

@ -0,0 +1,155 @@
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0, as
# published by the Free Software Foundation.
#
# This program is designed to work with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms,
# as designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an
# additional permission to link the program and your derivative works
# with the separately licensed software that they have either included with
# the program or referenced in the documentation.
#
# Without limiting anything contained in the foregoing, this file,
# which is part of MySQL Connector/Python, is also subject to the
# Universal FOSS Exception, version 1.0, a copy of which can be found at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
# mypy: disable-error-code="attr-defined"
# pylint: disable=protected-access
"""Utilities."""
__all__ = ["to_thread", "open_connection"]
import asyncio
import contextvars
import functools
try:
import ssl
except ImportError:
ssl = None
from typing import TYPE_CHECKING, Any, Callable, Tuple
if TYPE_CHECKING:
from mysql.connector.aio.abstracts import MySQLConnectionAbstract
__all__.append("StreamWriter")
class StreamReaderProtocol(asyncio.StreamReaderProtocol):
"""Extends asyncio.streams.StreamReaderProtocol for adding start_tls().
The ``start_tls()`` is based on ``asyncio.streams.StreamWriter`` introduced
in Python 3.11. It provides the same functionality for older Python versions.
"""
def _replace_writer(self, writer: asyncio.StreamWriter) -> None:
"""Replace stream writer.
Args:
writer: Stream Writer.
"""
transport = writer.transport
self._stream_writer = writer
self._transport = transport
self._over_ssl = transport.get_extra_info("sslcontext") is not None
class StreamWriter(asyncio.streams.StreamWriter):
"""Extends asyncio.streams.StreamWriter for adding start_tls().
The ``start_tls()`` is based on ``asyncio.streams.StreamWriter`` introduced
in Python 3.11. It provides the same functionality for older Python versions.
"""
async def start_tls(
self,
ssl_context: ssl.SSLContext,
*,
server_hostname: str = None,
ssl_handshake_timeout: int = None,
) -> None:
"""Upgrade an existing stream-based connection to TLS.
Args:
ssl_context: Configured SSL context.
server_hostname: Server host name.
ssl_handshake_timeout: SSL handshake timeout.
"""
server_side = self._protocol._client_connected_cb is not None
protocol = self._protocol
await self.drain()
new_transport = await self._loop.start_tls(
# pylint: disable=access-member-before-definition
self._transport, # type: ignore[has-type]
protocol,
ssl_context,
server_side=server_side,
server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
)
self._transport = ( # pylint: disable=attribute-defined-outside-init
new_transport
)
protocol._replace_writer(self)
async def open_connection(
host: str = None, port: int = None, *, limit: int = 2**16, **kwds: Any
) -> Tuple[asyncio.StreamReader, StreamWriter]:
"""A wrapper for create_connection() returning a (reader, writer) pair.
This function is based on ``asyncio.streams.open_connection`` and adds a custom
stream reader.
MySQL expects TLS negotiation to happen in the middle of a TCP connection, not at
the start.
This function in conjunction with ``_StreamReaderProtocol`` and ``_StreamWriter``
allows the TLS negotiation on an existing connection.
Args:
host: Server host name.
port: Server port.
limit: The buffer size limit used by the returned ``StreamReader`` instance.
By default the limit is set to 64 KiB.
Returns:
tuple: Returns a pair of reader and writer objects that are instances of
``StreamReader`` and ``StreamWriter`` classes.
"""
loop = asyncio.get_running_loop()
reader = asyncio.streams.StreamReader(limit=limit, loop=loop)
protocol = StreamReaderProtocol(reader, loop=loop)
transport, _ = await loop.create_connection(lambda: protocol, host, port, **kwds)
writer = StreamWriter(transport, protocol, reader, loop)
return reader, writer
async def to_thread(func: Callable, *args: Any, **kwargs: Any) -> asyncio.Future:
"""Asynchronously run function ``func`` in a separate thread.
This function is based on ``asyncio.to_thread()`` introduced in Python 3.9, which
provides the same functionality for older Python versions.
Returns:
coroutine: A coroutine that can be awaited to get the eventual result of
``func``.
"""
loop = asyncio.get_running_loop()
ctx = contextvars.copy_context()
func_call = functools.partial(ctx.run, func, *args, **kwargs)
return await loop.run_in_executor(None, func_call)