Format files with Ruff (#17643)

I thought `ruff check` would also format, but it doesn't.

This runs `ruff format` in CI and in the dev scripts. The first commit is just a
run of `ruff format .` in the root of the repository.
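For reference, the two commands cover different ground: `ruff check` runs the linter (with `--fix` it only applies lint autofixes), while `ruff format` is the Black-style formatter. A rough sketch of how the commands below are used, with `--check` being the CI variant that reports unformatted files instead of rewriting them:

    poetry run ruff check --fix .      # lint and apply safe autofixes
    poetry run ruff format .           # reformat files in place
    poetry run ruff format --check .   # fail if any file would be reformatted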
Quentin Gliech 2024-09-02 13:39:04 +02:00 committed by GitHub
parent 709b7363fe
commit 7d52ce7d4b
152 changed files with 526 additions and 492 deletions

View file

@@ -29,10 +29,14 @@ jobs:
         with:
           install-project: "false"
-      - name: Run ruff
+      - name: Run ruff check
         continue-on-error: true
         run: poetry run ruff check --fix .
+      - name: Run ruff format
+        continue-on-error: true
+        run: poetry run ruff format --quiet .
+
       - run: cargo clippy --all-features --fix -- -D warnings
         continue-on-error: true

View file

@@ -131,9 +131,12 @@ jobs:
         with:
           install-project: "false"
-      - name: Check style
+      - name: Run ruff check
         run: poetry run ruff check --output-format=github .
+      - name: Run ruff format
+        run: poetry run ruff format --check .
+
   lint-mypy:
     runs-on: ubuntu-latest
     name: Typechecking

changelog.d/17643.misc (new file)
View file

@@ -0,0 +1 @@
+Replace `isort` and `black` with `ruff`.

View file

@@ -21,7 +21,8 @@
 #
 #
-""" Starts a synapse client console. """
+"""Starts a synapse client console."""
+
 import argparse
 import binascii
 import cmd

View file

@@ -31,6 +31,7 @@ Pydantic does not yet offer a strict mode, but it is planned for pydantic v2. Se
 until then, this script is a best effort to stop us from introducing type coersion bugs
 (like the infamous stringy power levels fixed in room version 10).
 """
+
 import argparse
 import contextlib
 import functools

View file

@@ -109,6 +109,9 @@ set -x
 # --quiet suppresses the update check.
 ruff check --quiet --fix "${files[@]}"

+# Reformat Python code.
+ruff format --quiet "${files[@]}"
+
 # Catch any common programming mistakes in Rust code.
 #
 # --bins, --examples, --lib, --tests combined explicitly disable checking

View file

@@ -20,8 +20,7 @@
 #
 #
-"""An interactive script for doing a release. See `cli()` below.
-"""
+"""An interactive script for doing a release. See `cli()` below."""
 import glob
 import json

View file

@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Contains *incomplete* type hints for txredisapi.
-"""
+"""Contains *incomplete* type hints for txredisapi."""
+
 from typing import Any, List, Optional, Type, Union
 from twisted.internet import protocol

View file

@@ -20,8 +20,7 @@
 #
 #
-""" This is an implementation of a Matrix homeserver.
-"""
+"""This is an implementation of a Matrix homeserver."""
 import os
 import sys

View file

@@ -171,7 +171,7 @@ def elide_http_methods_if_unconflicting(
     """
     def paths_to_methods_dict(
-        methods_and_paths: Iterable[Tuple[str, str]]
+        methods_and_paths: Iterable[Tuple[str, str]],
     ) -> Dict[str, Set[str]]:
         """
         Given (method, path) pairs, produces a dict from path to set of methods
@@ -201,7 +201,7 @@ def elide_http_methods_if_unconflicting(
     def simplify_path_regexes(
-        registrations: Dict[Tuple[str, str], EndpointDescription]
+        registrations: Dict[Tuple[str, str], EndpointDescription],
     ) -> Dict[Tuple[str, str], EndpointDescription]:
         """
         Simplify all the path regexes for the dict of endpoint descriptions,

View file

@@ -40,6 +40,7 @@ from synapse.storage.engines import create_engine
 class ReviewConfig(RootConfig):
     "A config class that just pulls out the database config"
+
     config_classes = [DatabaseConfig]
@@ -160,7 +161,11 @@ def main() -> None:
     with make_conn(database_config, engine, "review_recent_signups") as db_conn:
         # This generates a type of Cursor, not LoggingTransaction.
-        user_infos = get_recent_users(db_conn.cursor(), since_ms, exclude_users_with_appservice)  # type: ignore[arg-type]
+        user_infos = get_recent_users(
+            db_conn.cursor(),
+            since_ms,  # type: ignore[arg-type]
+            exclude_users_with_appservice,
+        )
     for user_info in user_infos:
         if exclude_users_with_email and user_info.emails:

View file

@@ -717,9 +717,7 @@ class Porter:
             return
         # Check if all background updates are done, abort if not.
-        updates_complete = (
-            await self.sqlite_store.db_pool.updates.has_completed_background_updates()
-        )
+        updates_complete = await self.sqlite_store.db_pool.updates.has_completed_background_updates()
         if not updates_complete:
             end_error = (
                 "Pending background updates exist in the SQLite3 database."
@@ -1095,10 +1093,10 @@ class Porter:
         return done, remaining + done
     async def _setup_state_group_id_seq(self) -> None:
-        curr_id: Optional[int] = (
-            await self.sqlite_store.db_pool.simple_select_one_onecol(
-                table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
-            )
-        )
+        curr_id: Optional[
+            int
+        ] = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
+        )
         if not curr_id:
@@ -1186,13 +1184,13 @@ class Porter:
         )
     async def _setup_auth_chain_sequence(self) -> None:
-        curr_chain_id: Optional[int] = (
-            await self.sqlite_store.db_pool.simple_select_one_onecol(
-                table="event_auth_chains",
-                keyvalues={},
-                retcol="MAX(chain_id)",
-                allow_none=True,
-            )
-        )
+        curr_chain_id: Optional[
+            int
+        ] = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="event_auth_chains",
+            keyvalues={},
+            retcol="MAX(chain_id)",
+            allow_none=True,
+        )
         def r(txn: LoggingTransaction) -> None:

View file

@@ -19,7 +19,8 @@
 #
 #
-"""Contains the URL paths to prefix various aspects of the server with. """
+"""Contains the URL paths to prefix various aspects of the server with."""
+
 import hmac
 from hashlib import sha256
 from urllib.parse import urlencode

View file

@@ -54,6 +54,7 @@ UP & quit +---------- YES SUCCESS
 This is all tied together by the AppServiceScheduler which DIs the required
 components.
 """
+
 import logging
 from typing import (
     TYPE_CHECKING,

View file

@@ -200,16 +200,13 @@ class KeyConfig(Config):
             )
         form_secret = 'form_secret: "%s"' % random_string_with_symbols(50)
-        return (
-            """\
+        return """\
         %(macaroon_secret_key)s
         %(form_secret)s
         signing_key_path: "%(base_key_name)s.signing.key"
         trusted_key_servers:
             - server_name: "matrix.org"
-        """
-            % locals()
-        )
+        """ % locals()
     def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]:
         """Read the signing keys in the given path.
@@ -249,7 +246,9 @@ class KeyConfig(Config):
                 if is_signing_algorithm_supported(key_id):
                     key_base64 = key_data["key"]
                     key_bytes = decode_base64(key_base64)
-                    verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes(key_id, key_bytes)  # type: ignore[assignment]
+                    verify_key: "VerifyKeyWithExpiry" = decode_verify_key_bytes(
+                        key_id, key_bytes
+                    )  # type: ignore[assignment]
                     verify_key.expired = key_data["expired_ts"]
                     keys[key_id] = verify_key
                 else:

View file

@@ -157,12 +157,9 @@ class LoggingConfig(Config):
         self, config_dir_path: str, server_name: str, **kwargs: Any
     ) -> str:
         log_config = os.path.join(config_dir_path, server_name + ".log.config")
-        return (
-            """\
+        return """\
         log_config: "%(log_config)s"
-        """
-            % locals()
-        )
+        """ % locals()
     def read_arguments(self, args: argparse.Namespace) -> None:
         if args.no_redirect_stdio is not None:

View file

@@ -828,13 +828,10 @@ class ServerConfig(Config):
         ).lstrip()
         if not unsecure_listeners:
-            unsecure_http_bindings = (
-                """- port: %(unsecure_port)s
+            unsecure_http_bindings = """- port: %(unsecure_port)s
             tls: false
             type: http
-            x_forwarded: true"""
-                % locals()
-            )
+            x_forwarded: true""" % locals()
             if not open_private_ports:
                 unsecure_http_bindings += (
@@ -853,16 +850,13 @@ class ServerConfig(Config):
         if not secure_listeners:
             secure_http_bindings = ""
-        return (
-            """\
+        return """\
         server_name: "%(server_name)s"
         pid_file: %(pid_file)s
         listeners:
 %(secure_http_bindings)s
 %(unsecure_http_bindings)s
-        """
-            % locals()
-        )
+        """ % locals()
     def read_arguments(self, args: argparse.Namespace) -> None:
         if args.manhole is not None:

View file

@@ -328,10 +328,11 @@ class WorkerConfig(Config):
         )
         # type-ignore: the expression `Union[A, B]` is not a Type[Union[A, B]] currently
-        self.instance_map: Dict[
-            str, InstanceLocationConfig
-        ] = parse_and_validate_mapping(
-            instance_map, InstanceLocationConfig  # type: ignore[arg-type]
+        self.instance_map: Dict[str, InstanceLocationConfig] = (
+            parse_and_validate_mapping(
+                instance_map,
+                InstanceLocationConfig,  # type: ignore[arg-type]
+            )
         )
         # Map from type of streams to source, c.f. WriterLocations.

View file

@@ -887,7 +887,8 @@ def _check_power_levels(
             raise SynapseError(400, f"{v!r} must be an integer.")
         if k in {"events", "notifications", "users"}:
             if not isinstance(v, collections.abc.Mapping) or not all(
-                type(v) is int for v in v.values()  # noqa: E721
+                type(v) is int
+                for v in v.values()  # noqa: E721
             ):
                 raise SynapseError(
                     400,

View file

@@ -80,7 +80,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
     # All methods that the module provides should be async, but this wasn't enforced
     # in the old module system, so we wrap them if needed
     def async_wrapper(
-        f: Optional[Callable[P, R]]
+        f: Optional[Callable[P, R]],
     ) -> Optional[Callable[P, Awaitable[R]]]:
         # f might be None if the callback isn't implemented by the module. In this
        # case we don't want to register a callback at all so we return None.

View file

@@ -504,7 +504,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
 def _encode_state_group_delta(
-    state_group_delta: Dict[Tuple[int, int], StateMap[str]]
+    state_group_delta: Dict[Tuple[int, int], StateMap[str]],
 ) -> List[Tuple[int, int, Optional[List[Tuple[str, str, str]]]]]:
     if not state_group_delta:
         return []
@@ -517,7 +517,7 @@ def _encode_state_group_delta(
 def _decode_state_group_delta(
-    input: List[Tuple[int, int, List[Tuple[str, str, str]]]]
+    input: List[Tuple[int, int, List[Tuple[str, str, str]]]],
 ) -> Dict[Tuple[int, int], StateMap[str]]:
     if not input:
         return {}
@@ -544,7 +544,7 @@ def _encode_state_dict(
 def _decode_state_dict(
-    input: Optional[List[Tuple[str, str, str]]]
+    input: Optional[List[Tuple[str, str, str]]],
 ) -> Optional[StateMap[str]]:
     """Decodes a state dict encoded using `_encode_state_dict` above"""
     if input is None:

View file

@@ -19,5 +19,4 @@
 #
 #
-""" This package includes all the federation specific logic.
-"""
+"""This package includes all the federation specific logic."""

View file

@@ -20,7 +20,7 @@
 #
 #
-""" This module contains all the persistence actions done by the federation
+"""This module contains all the persistence actions done by the federation
 package.
 These actions are mostly only used by the :py:mod:`.replication` module.

View file

@@ -859,7 +859,6 @@ class FederationMediaThumbnailServlet(BaseFederationServerServlet):
         request: SynapseRequest,
         media_id: str,
     ) -> None:
-
         width = parse_integer(request, "width", required=True)
         height = parse_integer(request, "height", required=True)
         method = parse_string(request, "method", "scale")

View file

@@ -19,7 +19,7 @@
 #
 #
-""" Defines the JSON structure of the protocol units used by the server to
+"""Defines the JSON structure of the protocol units used by the server to
 server protocol.
 """

View file

@@ -118,10 +118,10 @@ class AccountHandler:
             }
             if self._use_account_validity_in_account_status:
-                status["org.matrix.expired"] = (
-                    await self._account_validity_handler.is_user_expired(
-                        user_id.to_string()
-                    )
+                status[
+                    "org.matrix.expired"
+                ] = await self._account_validity_handler.is_user_expired(
+                    user_id.to_string()
                 )
         return status

View file

@@ -197,14 +197,15 @@ class AdminHandler:
             # events that we have and then filtering, this isn't the most
             # efficient method perhaps but it does guarantee we get everything.
             while True:
-                events, _ = (
-                    await self._store.paginate_room_events_by_topological_ordering(
-                        room_id=room_id,
-                        from_key=from_key,
-                        to_key=to_key,
-                        limit=100,
-                        direction=Direction.FORWARDS,
-                    )
+                (
+                    events,
+                    _,
+                ) = await self._store.paginate_room_events_by_topological_ordering(
+                    room_id=room_id,
+                    from_key=from_key,
+                    to_key=to_key,
+                    limit=100,
+                    direction=Direction.FORWARDS,
                 )
                 if not events:
                     break

View file

@@ -166,8 +166,7 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
     if "country" not in identifier or (
         # The specification requires a "phone" field, while Synapse used to require a "number"
         # field. Accept both for backwards compatibility.
-        "phone" not in identifier
-        and "number" not in identifier
+        "phone" not in identifier and "number" not in identifier
     ):
         raise SynapseError(
             400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM

View file

@@ -265,9 +265,9 @@ class DirectoryHandler:
     async def get_association(self, room_alias: RoomAlias) -> JsonDict:
         room_id = None
         if self.hs.is_mine(room_alias):
-            result: Optional[RoomAliasMapping] = (
-                await self.get_association_from_room_alias(room_alias)
-            )
+            result: Optional[
+                RoomAliasMapping
+            ] = await self.get_association_from_room_alias(room_alias)
             if result:
                 room_id = result.room_id
@@ -512,11 +512,9 @@ class DirectoryHandler:
             raise SynapseError(403, "Not allowed to publish room")
         # Check if publishing is blocked by a third party module
-        allowed_by_third_party_rules = (
-            await (
-                self._third_party_event_rules.check_visibility_can_be_modified(
-                    room_id, visibility
-                )
+        allowed_by_third_party_rules = await (
+            self._third_party_event_rules.check_visibility_can_be_modified(
+                room_id, visibility
             )
         )
         if not allowed_by_third_party_rules:

View file

@@ -1001,11 +1001,11 @@ class FederationHandler:
         )
         if include_auth_user_id:
-            event_content[EventContentFields.AUTHORISING_USER] = (
-                await self._event_auth_handler.get_user_which_could_invite(
-                    room_id,
-                    state_ids,
-                )
+            event_content[
+                EventContentFields.AUTHORISING_USER
+            ] = await self._event_auth_handler.get_user_which_could_invite(
+                room_id,
+                state_ids,
             )
         builder = self.event_builder_factory.for_room_version(

View file

@@ -21,6 +21,7 @@
 #
 """Utilities for interacting with Identity Servers"""
+
 import logging
 import urllib.parse
 from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple

View file

@@ -1225,10 +1225,9 @@ class EventCreationHandler:
             )
         if prev_event_ids is not None:
-            assert (
-                len(prev_event_ids) <= 10
-            ), "Attempting to create an event with %i prev_events" % (
-                len(prev_event_ids),
+            assert len(prev_event_ids) <= 10, (
+                "Attempting to create an event with %i prev_events"
+                % (len(prev_event_ids),)
             )
         else:
             prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)

View file

@@ -507,15 +507,16 @@ class PaginationHandler:
         # Initially fetch the events from the database. With any luck, we can return
         # these without blocking on backfill (handled below).
-        events, next_key = (
-            await self.store.paginate_room_events_by_topological_ordering(
-                room_id=room_id,
-                from_key=from_token.room_key,
-                to_key=to_room_key,
-                direction=pagin_config.direction,
-                limit=pagin_config.limit,
-                event_filter=event_filter,
-            )
+        (
+            events,
+            next_key,
+        ) = await self.store.paginate_room_events_by_topological_ordering(
+            room_id=room_id,
+            from_key=from_token.room_key,
+            to_key=to_room_key,
+            direction=pagin_config.direction,
+            limit=pagin_config.limit,
+            event_filter=event_filter,
         )
         if pagin_config.direction == Direction.BACKWARDS:
@@ -584,15 +585,16 @@ class PaginationHandler:
         # If we did backfill something, refetch the events from the database to
         # catch anything new that might have been added since we last fetched.
         if did_backfill:
-            events, next_key = (
-                await self.store.paginate_room_events_by_topological_ordering(
-                    room_id=room_id,
-                    from_key=from_token.room_key,
-                    to_key=to_room_key,
-                    direction=pagin_config.direction,
-                    limit=pagin_config.limit,
-                    event_filter=event_filter,
-                )
+            (
+                events,
+                next_key,
+            ) = await self.store.paginate_room_events_by_topological_ordering(
+                room_id=room_id,
+                from_key=from_token.room_key,
+                to_key=to_room_key,
+                direction=pagin_config.direction,
+                limit=pagin_config.limit,
+                event_filter=event_filter,
             )
         else:
             # Otherwise, we can backfill in the background for eventual

View file

@@ -71,6 +71,7 @@ user state; this device follows the normal timeout logic (see above) and will
 automatically be replaced with any information from currently available devices.
 """
+
 import abc
 import contextlib
 import itertools
@@ -493,9 +494,9 @@ class WorkerPresenceHandler(BasePresenceHandler):
         # The number of ongoing syncs on this process, by (user ID, device ID).
         # Empty if _presence_enabled is false.
-        self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = (
-            {}
-        )
+        self._user_device_to_num_current_syncs: Dict[
+            Tuple[str, Optional[str]], int
+        ] = {}
         self.notifier = hs.get_notifier()
         self.instance_id = hs.get_instance_id()
@@ -818,9 +819,9 @@ class PresenceHandler(BasePresenceHandler):
         # Keeps track of the number of *ongoing* syncs on this process. While
         # this is non zero a user will never go offline.
-        self._user_device_to_num_current_syncs: Dict[Tuple[str, Optional[str]], int] = (
-            {}
-        )
+        self._user_device_to_num_current_syncs: Dict[
+            Tuple[str, Optional[str]], int
+        ] = {}
         # Keeps track of the number of *ongoing* syncs on other processes.
         #

View file

@@ -351,9 +351,9 @@ class ProfileHandler:
             server_name = host
         if self._is_mine_server_name(server_name):
-            media_info: Optional[Union[LocalMedia, RemoteMedia]] = (
-                await self.store.get_local_media(media_id)
-            )
+            media_info: Optional[
+                Union[LocalMedia, RemoteMedia]
+            ] = await self.store.get_local_media(media_id)
         else:
             media_info = await self.store.get_cached_remote_media(server_name, media_id)

View file

@@ -188,13 +188,13 @@ class RelationsHandler:
         if include_original_event:
             # Do not bundle aggregations when retrieving the original event because
             # we want the content before relations are applied to it.
-            return_value["original_event"] = (
-                await self._event_serializer.serialize_event(
-                    event,
-                    now,
-                    bundle_aggregations=None,
-                    config=serialize_options,
-                )
+            return_value[
+                "original_event"
+            ] = await self._event_serializer.serialize_event(
+                event,
+                now,
+                bundle_aggregations=None,
+                config=serialize_options,
             )
         if next_token:

View file

@@ -20,6 +20,7 @@
 #
 """Contains functions for performing actions on rooms."""
+
 import itertools
 import logging
 import math
@@ -900,11 +901,9 @@ class RoomCreationHandler:
         )
         # Check whether this visibility value is blocked by a third party module
-        allowed_by_third_party_rules = (
-            await (
-                self._third_party_event_rules.check_visibility_can_be_modified(
-                    room_id, visibility
-                )
+        allowed_by_third_party_rules = await (
+            self._third_party_event_rules.check_visibility_can_be_modified(
+                room_id, visibility
             )
         )
         if not allowed_by_third_party_rules:

View file

@@ -1302,11 +1302,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             # If this is going to be a local join, additional information must
             # be included in the event content in order to efficiently validate
             # the event.
-            content[EventContentFields.AUTHORISING_USER] = (
-                await self.event_auth_handler.get_user_which_could_invite(
-                    room_id,
-                    state_before_join,
-                )
+            content[
+                EventContentFields.AUTHORISING_USER
+            ] = await self.event_auth_handler.get_user_which_could_invite(
+                room_id,
+                state_before_join,
             )
         return False, []
@@ -1415,9 +1415,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         if requester is not None:
             sender = UserID.from_string(event.sender)
-            assert (
-                sender == requester.user
-            ), "Sender (%s) must be same as requester (%s)" % (sender, requester.user)
+            assert sender == requester.user, (
+                "Sender (%s) must be same as requester (%s)" % (sender, requester.user)
+            )
             assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
         else:
             requester = types.create_requester(target_user)

View file

@@ -423,9 +423,9 @@ class SearchHandler:
         }
         if search_result.room_groups and "room_id" in group_keys:
-            rooms_cat_res.setdefault("groups", {})[
-                "room_id"
-            ] = search_result.room_groups
+            rooms_cat_res.setdefault("groups", {})["room_id"] = (
+                search_result.room_groups
+            )
         if sender_group and "sender" in group_keys:
             rooms_cat_res.setdefault("groups", {})["sender"] = sender_group

View file

@@ -587,9 +587,7 @@ class SlidingSyncHandler:
                 Membership.LEAVE,
                 Membership.BAN,
             ):
-                to_bound = (
-                    room_membership_for_user_at_to_token.event_pos.to_room_stream_token()
-                )
+                to_bound = room_membership_for_user_at_to_token.event_pos.to_room_stream_token()
         timeline_from_bound = from_bound
         if ignore_timeline_bound:

View file

@@ -386,9 +386,9 @@ class SlidingSyncExtensionHandler:
             if have_push_rules_changed:
                 global_account_data_map = dict(global_account_data_map)
                 # TODO: This should take into account the `from_token` and `to_token`
-                global_account_data_map[AccountDataTypes.PUSH_RULES] = (
-                    await self.push_rules_handler.push_rules_for_user(sync_config.user)
-                )
+                global_account_data_map[
+                    AccountDataTypes.PUSH_RULES
+                ] = await self.push_rules_handler.push_rules_for_user(sync_config.user)
         else:
             # TODO: This should take into account the `to_token`
             all_global_account_data = await self.store.get_global_account_data_for_user(
@@ -397,9 +397,9 @@ class SlidingSyncExtensionHandler:
             global_account_data_map = dict(all_global_account_data)
             # TODO: This should take into account the `to_token`
-            global_account_data_map[AccountDataTypes.PUSH_RULES] = (
-                await self.push_rules_handler.push_rules_for_user(sync_config.user)
-            )
+            global_account_data_map[
+                AccountDataTypes.PUSH_RULES
+            ] = await self.push_rules_handler.push_rules_for_user(sync_config.user)
         # Fetch room account data
         account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {}

View file

@@ -293,10 +293,11 @@ class SlidingSyncRoomLists:
                     is_encrypted=is_encrypted,
                 )
-        newly_joined_room_ids, newly_left_room_map = (
-            await self._get_newly_joined_and_left_rooms(
-                user_id, from_token=from_token, to_token=to_token
-            )
+        (
+            newly_joined_room_ids,
+            newly_left_room_map,
+        ) = await self._get_newly_joined_and_left_rooms(
+            user_id, from_token=from_token, to_token=to_token
         )
         dm_room_ids = await self._get_dm_rooms_for_user(user_id)
@@ -958,10 +959,11 @@ class SlidingSyncRoomLists:
             else:
                 rooms_for_user[room_id] = change_room_for_user
-        newly_joined_room_ids, newly_left_room_ids = (
-            await self._get_newly_joined_and_left_rooms(
-                user_id, to_token=to_token, from_token=from_token
-            )
+        (
+            newly_joined_room_ids,
+            newly_left_room_ids,
+        ) = await self._get_newly_joined_and_left_rooms(
+            user_id, to_token=to_token, from_token=from_token
        )
        dm_room_ids = await self._get_dm_rooms_for_user(user_id)

View file

@@ -183,10 +183,7 @@ class JoinedSyncResult:
         to tell if room needs to be part of the sync result.
         """
         return bool(
-            self.timeline
-            or self.state
-            or self.ephemeral
-            or self.account_data
+            self.timeline or self.state or self.ephemeral or self.account_data
             # nb the notification count does not, er, count: if there's nothing
             # else in the result, we don't need to send it.
         )
@@ -575,10 +572,10 @@ class SyncHandler:
         if timeout == 0 or since_token is None or full_state:
             # we are going to return immediately, so don't bother calling
             # notifier.wait_for_events.
-            result: Union[SyncResult, E2eeSyncResult] = (
-                await self.current_sync_for_user(
-                    sync_config, sync_version, since_token, full_state=full_state
-                )
+            result: Union[
+                SyncResult, E2eeSyncResult
+            ] = await self.current_sync_for_user(
+                sync_config, sync_version, since_token, full_state=full_state
             )
         else:
             # Otherwise, we wait for something to happen and report it to the user.
@@ -673,10 +670,10 @@ class SyncHandler:
         # Go through the `/sync` v2 path
         if sync_version == SyncVersion.SYNC_V2:
-            sync_result: Union[SyncResult, E2eeSyncResult] = (
-                await self.generate_sync_result(
-                    sync_config, since_token, full_state
-                )
+            sync_result: Union[
+                SyncResult, E2eeSyncResult
+            ] = await self.generate_sync_result(
+                sync_config, since_token, full_state
             )
         # Go through the MSC3575 Sliding Sync `/sync/e2ee` path
         elif sync_version == SyncVersion.E2EE_SYNC:
@@ -1488,13 +1485,16 @@ class SyncHandler:
                 # timeline here. The caller will then dedupe any redundant
                 # ones.
-                state_ids = await self._state_storage_controller.get_state_ids_for_event(
-                    batch.events[0].event_id,
-                    # we only want members!
-                    state_filter=StateFilter.from_types(
-                        (EventTypes.Member, member) for member in members_to_fetch
-                    ),
-                    await_full_state=False,
+                state_ids = (
+                    await self._state_storage_controller.get_state_ids_for_event(
+                        batch.events[0].event_id,
+                        # we only want members!
+                        state_filter=StateFilter.from_types(
+                            (EventTypes.Member, member)
+                            for member in members_to_fetch
+                        ),
+                        await_full_state=False,
+                    )
                 )
                 return state_ids
@@ -2166,18 +2166,18 @@ class SyncHandler:
         if push_rules_changed:
             global_account_data = dict(global_account_data)
-            global_account_data[AccountDataTypes.PUSH_RULES] = (
-                await self._push_rules_handler.push_rules_for_user(sync_config.user)
-            )
+            global_account_data[
+                AccountDataTypes.PUSH_RULES
+            ] = await self._push_rules_handler.push_rules_for_user(sync_config.user)
         else:
             all_global_account_data = await self.store.get_global_account_data_for_user(
                 user_id
             )
             global_account_data = dict(all_global_account_data)
-            global_account_data[AccountDataTypes.PUSH_RULES] = (
-                await self._push_rules_handler.push_rules_for_user(sync_config.user)
-            )
+            global_account_data[
+                AccountDataTypes.PUSH_RULES
+            ] = await self._push_rules_handler.push_rules_for_user(sync_config.user)
         account_data_for_user = (
             await sync_config.filter_collection.filter_global_account_data(

View file

@@ -183,7 +183,7 @@ class WorkerLocksHandler:
             return
         def _wake_all_locks(
-            locks: Collection[Union[WaitingLock, WaitingMultiLock]]
+            locks: Collection[Union[WaitingLock, WaitingMultiLock]],
         ) -> None:
             for lock in locks:
                 deferred = lock.deferred

View file

@@ -1313,6 +1313,5 @@ def is_unknown_endpoint(
         )
     ) or (
         # Older Synapses returned a 400 error.
-        e.code == 400
-        and synapse_error.errcode == Codes.UNRECOGNIZED
+        e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED
     )

View file

@@ -233,7 +233,7 @@ def return_html_error(
 def wrap_async_request_handler(
-    h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]]
+    h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]],
 ) -> Callable[["_AsyncResource", "SynapseRequest"], "defer.Deferred[None]"]:
     """Wraps an async request handler so that it calls request.processing.

View file

@@ -22,6 +22,7 @@
 """
 Log formatters that output terse JSON.
 """
+
 import json
 import logging

View file

@@ -20,7 +20,7 @@
 #
 #
-""" Thread-local-alike tracking of log contexts within synapse
+"""Thread-local-alike tracking of log contexts within synapse
 This module provides objects and utilities for tracking contexts through
 synapse code, so that log lines can include a request identifier, and so that
@@ -29,6 +29,7 @@ them.
 See doc/log_contexts.rst for details on how this works.
 """
+
 import logging
 import threading
 import typing
@@ -751,7 +752,7 @@ def preserve_fn(
     f: Union[
         Callable[P, R],
         Callable[P, Awaitable[R]],
-    ]
+    ],
 ) -> Callable[P, "defer.Deferred[R]"]:
     """Function decorator which wraps the function with run_in_background"""

View file

@@ -169,6 +169,7 @@ Gotchas
 than one caller? Will all of those calling functions have be in a context
 with an active span?
 """
+
 import contextlib
 import enum
 import inspect
@@ -414,7 +415,7 @@ def ensure_active_span(
     """
     def ensure_active_span_inner_1(
-        func: Callable[P, R]
+        func: Callable[P, R],
     ) -> Callable[P, Union[Optional[T], R]]:
         @wraps(func)
         def ensure_active_span_inner_2(
@@ -700,7 +701,7 @@ def set_operation_name(operation_name: str) -> None:
 @only_if_tracing
 def force_tracing(
-    span: Union["opentracing.Span", _Sentinel] = _Sentinel.sentinel
+    span: Union["opentracing.Span", _Sentinel] = _Sentinel.sentinel,
 ) -> None:
     """Force sampling for the active/given span and its children.
@@ -1093,9 +1094,10 @@ def trace_servlet(
         # Mypy seems to think that start_context.tag below can be Optional[str], but
        # that doesn't appear to be correct and works in practice.
-        request_tags[
-            SynapseTags.REQUEST_TAG
-        ] = request.request_metrics.start_context.tag  # type: ignore[assignment]
+        request_tags[SynapseTags.REQUEST_TAG] = (
+            request.request_metrics.start_context.tag  # type: ignore[assignment]
+        )
     # set the tags *after* the servlet completes, in case it decided to
     # prioritise the span (tags will get dropped on unprioritised spans)

View file

@@ -293,7 +293,7 @@ def wrap_as_background_process(
     """
     def wrap_as_background_process_inner(
-        func: Callable[P, Awaitable[Optional[R]]]
+        func: Callable[P, Awaitable[Optional[R]]],
    ) -> Callable[P, "defer.Deferred[Optional[R]]"]:
        @wraps(func)
        def wrap_as_background_process_inner_2(

View file

@@ -304,9 +304,9 @@ class BulkPushRuleEvaluator:
             if relation_type == "m.thread" and event.content.get(
                 "m.relates_to", {}
             ).get("is_falling_back", False):
-                related_events["m.in_reply_to"][
-                    "im.vector.is_falling_back"
-                ] = ""
+                related_events["m.in_reply_to"]["im.vector.is_falling_back"] = (
+                    ""
+                )
         return related_events
@@ -372,7 +372,8 @@ class BulkPushRuleEvaluator:
             gather_results(
                 (
                     run_in_background(  # type: ignore[call-arg]
-                        self.store.get_number_joined_users_in_room, event.room_id  # type: ignore[arg-type]
+                        self.store.get_number_joined_users_in_room,
+                        event.room_id,  # type: ignore[arg-type]
                     ),
                     run_in_background(
                         self._get_power_levels_and_sender_level,

View file

@@ -119,7 +119,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         return payload
-    async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]:  # type: ignore[override]
+    async def _handle_request(  # type: ignore[override]
+        self, request: Request, content: JsonDict
+    ) -> Tuple[int, JsonDict]:
         with Measure(self.clock, "repl_fed_send_events_parse"):
             room_id = content["room_id"]
             backfilled = content["backfilled"]

View file

@@ -98,7 +98,9 @@ class ReplicationCopyPusherRestServlet(ReplicationEndpoint):
         self._store = hs.get_datastores().main
     @staticmethod
-    async def _serialize_payload(user_id: str, old_room_id: str, new_room_id: str) -> JsonDict:  # type: ignore[override]
+    async def _serialize_payload(  # type: ignore[override]
+        user_id: str, old_room_id: str, new_room_id: str
+    ) -> JsonDict:
         return {}
     async def _handle_request(  # type: ignore[override]
@@ -109,7 +111,6 @@ class ReplicationCopyPusherRestServlet(ReplicationEndpoint):
         old_room_id: str,
         new_room_id: str,
     ) -> Tuple[int, JsonDict]:
-
         await self._store.copy_push_rules_from_room_to_room_for_user(
             old_room_id, new_room_id, user_id
         )

View file

@@ -18,8 +18,8 @@
 # [This file includes modifications made by New Vector Limited]
 #
 #
-"""A replication client for use by synapse workers.
-"""
+"""A replication client for use by synapse workers."""
+
 import logging
 from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple

View file

@@ -23,6 +23,7 @@
 The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
 allowed to be sent by which side.
 """
+
 import abc
 import logging
 from typing import List, Optional, Tuple, Type, TypeVar

View file

@@ -857,7 +857,7 @@ UpdateRow = TypeVar("UpdateRow")
 def _batch_updates(
-    updates: Iterable[Tuple[UpdateToken, UpdateRow]]
+    updates: Iterable[Tuple[UpdateToken, UpdateRow]],
 ) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]:
     """Collect stream updates with the same token together

View file

@@ -23,6 +23,7 @@ protocols.
 An explanation of this protocol is available in docs/tcp_replication.md
 """
+
 import fcntl
 import logging
 import struct

View file

@@ -18,8 +18,7 @@
 # [This file includes modifications made by New Vector Limited]
 #
 #
-"""The server side of the replication stream.
-"""
+"""The server side of the replication stream."""
 import logging
 import random
@@ -307,7 +306,7 @@ class ReplicationStreamer:
 def _batch_updates(
-    updates: List[Tuple[Token, StreamRow]]
+    updates: List[Tuple[Token, StreamRow]],
 ) -> List[Tuple[Optional[Token], StreamRow]]:
     """Takes a list of updates of form [(token, row)] and sets the token to
     None for all rows where the next row has the same token. This is used to

View file

@@ -247,7 +247,7 @@ class _StreamFromIdGen(Stream):
 def current_token_without_instance(
-    current_token: Callable[[], int]
+    current_token: Callable[[], int],
 ) -> Callable[[str], int]:
     """Takes a current token callback function for a single writer stream
     that doesn't take an instance name parameter and wraps it in a function that

View file

@@ -181,8 +181,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
             uses_allowed = body.get("uses_allowed", None)
             if not (
-                uses_allowed is None
-                or (type(uses_allowed) is int and uses_allowed >= 0)  # noqa: E721
+                uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0)  # noqa: E721
             ):
                 raise SynapseError(
                     HTTPStatus.BAD_REQUEST,

View file

@@ -19,8 +19,8 @@
 #
 #
-"""This module contains base REST classes for constructing client v1 servlets.
-"""
+"""This module contains base REST classes for constructing client v1 servlets."""
+
 import logging
 import re
 from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast

View file

@@ -108,9 +108,9 @@ class AccountDataServlet(RestServlet):
         # Push rules are stored in a separate table and must be queried separately.
         if account_data_type == AccountDataTypes.PUSH_RULES:
-            account_data: Optional[JsonMapping] = (
-                await self._push_rules_handler.push_rules_for_user(requester.user)
-            )
+            account_data: Optional[
+                JsonMapping
+            ] = await self._push_rules_handler.push_rules_for_user(requester.user)
         else:
             account_data = await self.store.get_global_account_data_by_type_for_user(
                 user_id, account_data_type

View file

@@ -48,9 +48,7 @@ class AccountValidityRenewServlet(RestServlet):
         self.account_renewed_template = (
             hs.config.account_validity.account_validity_account_renewed_template
         )
-        self.account_previously_renewed_template = (
-            hs.config.account_validity.account_validity_account_previously_renewed_template
-        )
+        self.account_previously_renewed_template = hs.config.account_validity.account_validity_account_previously_renewed_template
         self.invalid_token_template = (
             hs.config.account_validity.account_validity_invalid_token_template
         )

View file

@@ -20,6 +20,7 @@
 #
 """This module contains REST servlets to do with event streaming, /events."""
+
 import logging
 from typing import TYPE_CHECKING, Dict, List, Tuple, Union

View file

@@ -19,8 +19,8 @@
 #
 #
-""" This module contains REST servlets to do with presence: /presence/<paths>
-"""
+"""This module contains REST servlets to do with presence: /presence/<paths>"""
+
 import logging
 from typing import TYPE_CHECKING, Tuple

View file

@@ -19,7 +19,7 @@
 #
 #
-""" This module contains REST servlets to do with profile: /profile/<paths> """
+"""This module contains REST servlets to do with profile: /profile/<paths>"""
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple

View file

@@ -640,12 +640,10 @@ class RegisterRestServlet(RestServlet):
             if not password_hash:
                 raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
-            desired_username = (
-                await (
-                    self.password_auth_provider.get_username_for_registration(
-                        auth_result,
-                        params,
-                    )
+            desired_username = await (
+                self.password_auth_provider.get_username_for_registration(
+                    auth_result,
+                    params,
                 )
             )
@@ -696,11 +694,9 @@ class RegisterRestServlet(RestServlet):
                     session_id
                 )
-                display_name = (
-                    await (
-                        self.password_auth_provider.get_displayname_for_registration(
-                            auth_result, params
-                        )
+                display_name = await (
+                    self.password_auth_provider.get_displayname_for_registration(
+                        auth_result, params
                     )
                 )

View file

@@ -19,7 +19,8 @@
 #
 #
-""" This module contains REST servlets to do with rooms: /rooms/<paths> """
+"""This module contains REST servlets to do with rooms: /rooms/<paths>"""
+
 import logging
 import re
 from enum import Enum

View file

@@ -1045,9 +1045,9 @@ class SlidingSyncRestServlet(RestServlet):
                 serialized_rooms[room_id]["initial"] = room_result.initial
             if room_result.unstable_expanded_timeline:
-                serialized_rooms[room_id][
-                    "unstable_expanded_timeline"
-                ] = room_result.unstable_expanded_timeline
+                serialized_rooms[room_id]["unstable_expanded_timeline"] = (
+                    room_result.unstable_expanded_timeline
+                )
             # This will be omitted for invite/knock rooms with `stripped_state`
             if (
@@ -1082,9 +1082,9 @@ class SlidingSyncRestServlet(RestServlet):
             # This will be omitted for invite/knock rooms with `stripped_state`
             if room_result.prev_batch is not None:
-                serialized_rooms[room_id]["prev_batch"] = (
-                    await room_result.prev_batch.to_string(self.store)
-                )
+                serialized_rooms[room_id][
+                    "prev_batch"
+                ] = await room_result.prev_batch.to_string(self.store)
             # This will be omitted for invite/knock rooms with `stripped_state`
             if room_result.num_live is not None:

View file

@@ -21,6 +21,7 @@
 """This module contains logic for storing HTTP PUT transactions. This is used
 to ensure idempotency when performing PUTs using the REST API."""
+
 import logging
 from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Hashable, Tuple

View file

@@ -191,10 +191,10 @@ class RemoteKey(RestServlet):
         server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
         for server_name, key_ids in query.items():
             if key_ids:
-                results: Mapping[str, Optional[FetchKeyResultForRemote]] = (
-                    await self.store.get_server_keys_json_for_remote(
-                        server_name, key_ids
-                    )
+                results: Mapping[
+                    str, Optional[FetchKeyResultForRemote]
+                ] = await self.store.get_server_keys_json_for_remote(
+                    server_name, key_ids
                 )
             else:
                 results = await self.store.get_all_server_keys_json_for_remote(

View file

@@ -65,9 +65,9 @@ class WellKnownBuilder:
             }
             account_management_url = await auth.account_management_url()
             if account_management_url is not None:
-                result["org.matrix.msc2965.authentication"][
-                    "account"
-                ] = account_management_url
+                result["org.matrix.msc2965.authentication"]["account"] = (
+                    account_management_url
+                )
         if self._config.server.extra_well_known_client_content:
             for (

View file

@@ -119,7 +119,9 @@ class ResourceLimitsServerNotices:
             elif not currently_blocked and limit_msg:
                 # Room is not notifying of a block, when it ought to be.
                 await self._apply_limit_block_notification(
-                    user_id, limit_msg, limit_type  # type: ignore
+                    user_id,
+                    limit_msg,
+                    limit_type,  # type: ignore
                 )
         except SynapseError as e:
             logger.error("Error sending resource limits server notice: %s", e)

View file

@@ -416,7 +416,7 @@ class EventsPersistenceStorageController:
         set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
         async def enqueue(
-            item: Tuple[str, List[Tuple[EventBase, EventContext]]]
+            item: Tuple[str, List[Tuple[EventBase, EventContext]]],
         ) -> Dict[str, str]:
             room_id, evs_ctxs = item
             return await self._event_persist_queue.add_to_queue(
@@ -792,9 +792,9 @@ class EventsPersistenceStorageController:
             )
             # Remove any events which are prev_events of any existing events.
-            existing_prevs: Collection[str] = (
-                await self.persist_events_store._get_events_which_are_prevs(result)
-            )
+            existing_prevs: Collection[
+                str
+            ] = await self.persist_events_store._get_events_which_are_prevs(result)
             result.difference_update(existing_prevs)
             # Finally handle the case where the new events have soft-failed prev

View file

@@ -238,9 +238,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 INNER JOIN user_ips USING (user_id, access_token, ip)
                 GROUP BY user_id, access_token, ip
                 HAVING count(*) > 1
-                """.format(
-                    clause
-                ),
+                """.format(clause),
                 args,
             )
             res = cast(
@@ -373,9 +371,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 LIMIT ?
             ) c
             INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
-            """ % {
-                "where_clause": where_clause
-            }
+            """ % {"where_clause": where_clause}

             txn.execute(sql, where_args + [batch_size])
             rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
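Much of the churn in the storage modules is this one pattern: a short .format(...) or %-interpolation argument that used to be exploded over three lines is pulled back onto the line with the closing quotes. A runnable sketch with a made-up query:

clause = "user_id = ?"

# Before: the single argument got its own lines.
sql_before = """
    SELECT user_id FROM user_ips WHERE %s
""" % (
    clause,
)

# After: the argument sits on the same line as the closing quotes.
sql_after = """
    SELECT user_id FROM user_ips WHERE %s
""" % (clause,)

assert sql_before == sql_after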

View file

@@ -1116,7 +1116,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
             txn.execute(sql, (start, stop))

-            destinations = {d for d, in txn}
+            destinations = {d for (d,) in txn}
             to_remove = set()
             for d in destinations:
                 try:
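Rewrites of this shape make up most of the remaining hunks, and they are all one rule: a single-name tuple target in a comprehension or for loop gains explicit parentheses. A sketch with a plain list standing in for the database cursor:

# Rows as a DB-API cursor would yield them: one-element tuples.
rows = [("@alice:example.org",), ("@bob:example.org",)]

# Before: bare tuple target, easy to misread as a stray comma.
users_before = {user_id for user_id, in rows}

# After: the target is parenthesised.
users_after = {user_id for (user_id,) in rows}

assert users_before == users_after == {"@alice:example.org", "@bob:example.org"}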

View file

@@ -670,9 +670,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                 result["keys"] = keys

                 device_display_name = None
-                if (
-                    self.hs.config.federation.allow_device_name_lookup_over_federation
-                ):
+                if self.hs.config.federation.allow_device_name_lookup_over_federation:
                     device_display_name = device.display_name
                 if device_display_name:
                     result["device_display_name"] = device_display_name
@@ -917,7 +915,7 @@
                 from_key,
                 to_key,
             )
-            return {u for u, in rows}
+            return {u for (u,) in rows}

     @cancellable
     async def get_users_whose_devices_changed(
@@ -968,7 +966,7 @@
                     txn.database_engine, "user_id", chunk
                 )
                 txn.execute(sql % (clause,), [from_key, to_key] + args)
-                changes.update(user_id for user_id, in txn)
+                changes.update(user_id for (user_id,) in txn)

             return changes
@@ -1520,7 +1518,7 @@
             args: List[Any],
         ) -> Set[str]:
             txn.execute(sql.format(clause=clause), args)
-            return {user_id for user_id, in txn}
+            return {user_id for (user_id,) in txn}

         changes = set()
         for chunk in batch_iter(changed_room_ids, 1000):
@@ -1560,7 +1558,7 @@
             txn: LoggingTransaction,
         ) -> Set[str]:
             txn.execute(sql, (from_id, to_id))
-            return {room_id for room_id, in txn}
+            return {room_id for (room_id,) in txn}

         return await self.db_pool.runInteraction(
             "get_all_device_list_changes",

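The first of the hunks above also removes now-redundant parentheses: a condition that only needed wrapping to satisfy a line split is folded back onto the if line. A hypothetical sketch with a stand-in config object:

class _FederationConfig:
    # Hypothetical stand-in for hs.config.federation.
    allow_device_name_lookup_over_federation = True


config = _FederationConfig()

# Before: the lone condition was wrapped and indented.
if (
    config.allow_device_name_lookup_over_federation
):
    print("device name lookup over federation allowed")

# After: the condition sits back on the if line.
if config.allow_device_name_lookup_over_federation:
    print("device name lookup over federation allowed")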
View file

@@ -387,9 +387,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
                 is_verified, session_data
                 FROM e2e_room_keys
                 WHERE user_id = ? AND version = ? AND (%s)
-            """ % (
-                " OR ".join(where_clauses)
-            )
+            """ % (" OR ".join(where_clauses))

             txn.execute(sql, params)

View file

@@ -472,9 +472,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             signature_sql = """
                 SELECT user_id, key_id, target_device_id, signature
                 FROM e2e_cross_signing_signatures WHERE %s
-            """ % (
-                " OR ".join("(" + q + ")" for q in signature_query_clauses)
-            )
+            """ % (" OR ".join("(" + q + ")" for q in signature_query_clauses))

             txn.execute(signature_sql, signature_query_params)
             return cast(
@@ -917,9 +915,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 FROM e2e_cross_signing_keys
                 WHERE %(clause)s
                 ORDER BY user_id, keytype, stream_id DESC
-            """ % {
-                "clause": clause
-            }
+            """ % {"clause": clause}
         else:
             # SQLite has special handling for bare columns when using
             # MIN/MAX with a `GROUP BY` clause where it picks the value from
@@ -929,9 +925,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 FROM e2e_cross_signing_keys
                 WHERE %(clause)s
                 GROUP BY user_id, keytype
-            """ % {
-                "clause": clause
-            }
+            """ % {"clause": clause}

             txn.execute(sql, params)

View file

@@ -326,7 +326,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 """

                 rows = txn.execute_values(sql, chains.items())
-                results.update(r for r, in rows)
+                results.update(r for (r,) in rows)
             else:
                 # For SQLite we just fall back to doing a noddy for loop.
                 sql = """
@@ -335,7 +335,7 @@
                 """
                 for chain_id, max_no in chains.items():
                     txn.execute(sql, (chain_id, max_no))
-                    results.update(r for r, in txn)
+                    results.update(r for (r,) in txn)

         return results
@@ -645,7 +645,7 @@
                 ]

                 rows = txn.execute_values(sql, args)
-                result.update(r for r, in rows)
+                result.update(r for (r,) in rows)
             else:
                 # For SQLite we just fall back to doing a noddy for loop.
                 sql = """
@@ -654,7 +654,7 @@
                 """
                 for chain_id, (min_no, max_no) in chain_to_gap.items():
                     txn.execute(sql, (chain_id, min_no, max_no))
-                    result.update(r for r, in txn)
+                    result.update(r for (r,) in txn)

         return result
@@ -1220,13 +1220,11 @@
                 HAVING count(*) > ?
                 ORDER BY count(*) DESC
                 LIMIT ?
-            """ % (
-                where_clause,
-            )
+            """ % (where_clause,)

             query_args = list(itertools.chain(room_id_filter, [min_count, limit]))
             txn.execute(sql, query_args)
-            return [room_id for room_id, in txn]
+            return [room_id for (room_id,) in txn]

         return await self.db_pool.runInteraction(
             "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
@@ -1358,7 +1356,7 @@
         def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
             txn.execute(sql, (stream_ordering, room_id))
-            return [event_id for event_id, in txn]
+            return [event_id for (event_id,) in txn]

         event_ids = await self.db_pool.runInteraction(
             "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn

View file

@@ -1860,9 +1860,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 AND epa.notif = 1
                 ORDER BY epa.stream_ordering DESC
                 LIMIT ?
-            """ % (
-                before_clause,
-            )
+            """ % (before_clause,)
             txn.execute(sql, args)
             return cast(
                 List[Tuple[str, str, int, int, str, bool, str, int]], txn.fetchall()

View file

@@ -429,9 +429,7 @@ class PersistEventsStore:
             if event_type == EventTypes.Member and self.is_mine_id(state_key)
         ]

-        membership_snapshot_shared_insert_values: (
-            SlidingSyncMembershipSnapshotSharedInsertValues
-        ) = {}
+        membership_snapshot_shared_insert_values: SlidingSyncMembershipSnapshotSharedInsertValues = {}
         membership_infos_to_insert_membership_snapshots: List[
             SlidingSyncMembershipInfo
         ] = []
@@ -719,7 +717,7 @@
             keyvalues={},
             retcols=("event_id",),
         )
-        already_persisted_events = {event_id for event_id, in rows}
+        already_persisted_events = {event_id for (event_id,) in rows}
         state_events = [
             event
             for event in state_events
@@ -1830,12 +1828,8 @@
         if sliding_sync_table_changes.to_insert_membership_snapshots:
             # Update the `sliding_sync_membership_snapshots` table
             #
-            sliding_sync_snapshot_keys = (
-                sliding_sync_table_changes.membership_snapshot_shared_insert_values.keys()
-            )
-            sliding_sync_snapshot_values = (
-                sliding_sync_table_changes.membership_snapshot_shared_insert_values.values()
-            )
+            sliding_sync_snapshot_keys = sliding_sync_table_changes.membership_snapshot_shared_insert_values.keys()
+            sliding_sync_snapshot_values = sliding_sync_table_changes.membership_snapshot_shared_insert_values.values()
             # We need to insert/update regardless of whether we have
             # `sliding_sync_snapshot_keys` because there are other fields in the `ON
             # CONFLICT` upsert to run (see inherit case (explained in
@@ -3361,7 +3355,7 @@
         )

         potential_backwards_extremities.difference_update(
-            e for e, in existing_events_outliers
+            e for (e,) in existing_events_outliers
         )

         if potential_backwards_extremities:

View file

@@ -647,7 +647,8 @@ class EventsBackgroundUpdatesStore(StreamWorkerStore, StateDeltasStore, SQLBaseS
             room_ids = {row[0] for row in rows}
             for room_id in room_ids:
                 txn.call_after(
-                    self.get_latest_event_ids_in_room.invalidate, (room_id,)  # type: ignore[attr-defined]
+                    self.get_latest_event_ids_in_room.invalidate,  # type: ignore[attr-defined]
+                    (room_id,),
                 )

             self.db_pool.simple_delete_many_txn(
@@ -2065,9 +2066,7 @@
             )

             # Map of values to insert/update in the `sliding_sync_membership_snapshots` table
-            sliding_sync_membership_snapshots_insert_map: (
-                SlidingSyncMembershipSnapshotSharedInsertValues
-            ) = {}
+            sliding_sync_membership_snapshots_insert_map: SlidingSyncMembershipSnapshotSharedInsertValues = {}
             if membership == Membership.JOIN:
                 # If we're still joined, we can pull from current state.
                 current_state_ids_map: StateMap[
@@ -2149,14 +2148,15 @@
                 # membership (i.e. the room shouldn't disappear if your using the
                 # `is_encrypted` filter and you leave).
                 if membership in (Membership.LEAVE, Membership.BAN) and is_outlier:
-                    invite_or_knock_event_id, invite_or_knock_membership = (
-                        await self.db_pool.runInteraction(
-                            "sliding_sync_membership_snapshots_bg_update._find_previous_membership",
-                            _find_previous_membership_txn,
-                            room_id,
-                            user_id,
-                            membership_event_id,
-                        )
-                    )
+                    (
+                        invite_or_knock_event_id,
+                        invite_or_knock_membership,
+                    ) = await self.db_pool.runInteraction(
+                        "sliding_sync_membership_snapshots_bg_update._find_previous_membership",
+                        _find_previous_membership_txn,
+                        room_id,
+                        user_id,
+                        membership_event_id,
+                    )

                     # Pull from the stripped state on the invite/knock event
@@ -2484,9 +2484,7 @@ def _resolve_stale_data_in_sliding_sync_joined_rooms_table(
             "progress_json": "{}",
         },
     )
-    depends_on = (
-        _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE
-    )
+    depends_on = _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE

     # Now kick-off the background update to catch-up with what we missed while Synapse
     # was downgraded.

View file

@@ -1665,7 +1665,7 @@ class EventsWorkerStore(SQLBaseStore):
                 txn.database_engine, "e.event_id", event_ids
             )
             txn.execute(sql + clause, args)
-            found_events = {eid for eid, in txn}
+            found_events = {eid for (eid,) in txn}

             # ... and then we can update the results for each key
             return {eid: (eid in found_events) for eid in event_ids}
@@ -1864,9 +1864,9 @@
                 " LIMIT ?"
             )
             txn.execute(sql, (-last_id, -current_id, instance_name, limit))
-            new_event_updates: List[Tuple[int, Tuple[str, str, str, str, str, str]]] = (
-                []
-            )
+            new_event_updates: List[
+                Tuple[int, Tuple[str, str, str, str, str, str]]
+            ] = []
             row: Tuple[int, str, str, str, str, str, str]
             # Type safety: iterating over `txn` yields `Tuple`, i.e.
             # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a

View file

@@ -201,7 +201,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         txn.execute_batch(
             "INSERT INTO event_backward_extremities (room_id, event_id)"
             " VALUES (?, ?)",
-            [(room_id, event_id) for event_id, in new_backwards_extrems],
+            [(room_id, event_id) for (event_id,) in new_backwards_extrems],
         )

         logger.info("[purge] finding state groups referenced by deleted events")
@@ -215,7 +215,7 @@
             """
         )

-        referenced_state_groups = {sg for sg, in txn}
+        referenced_state_groups = {sg for (sg,) in txn}
         logger.info(
             "[purge] found %i referenced state groups", len(referenced_state_groups)
         )

View file

@@ -762,7 +762,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)

-            return [room_id for room_id, in txn]
+            return [room_id for (room_id,) in txn]

         results: List[str] = []
         for batch in batch_iter(room_ids, 1000):
@@ -1030,9 +1030,7 @@
                 SELECT event_id WHERE room_id = ? AND stream_ordering IN (
                     SELECT max(stream_ordering) WHERE %s
                 )
-            """ % (
-                clause,
-            )
+            """ % (clause,)

             txn.execute(sql, [room_id] + list(args))

             rows = txn.fetchall()

View file

@@ -1250,9 +1250,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             SELECT address, session_id, medium, client_secret,
             last_send_attempt, validated_at
             FROM threepid_validation_session WHERE %s
-        """ % (
-            " AND ".join("%s = ?" % k for k in keyvalues.keys()),
-        )
+        """ % (" AND ".join("%s = ?" % k for k in keyvalues.keys()),)

         if validated is not None:
             sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")

View file

@@ -1608,9 +1608,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 FROM event_reports AS er
                 JOIN room_stats_state ON room_stats_state.room_id = er.room_id
                 {}
-            """.format(
-                where_clause
-            )
+            """.format(where_clause)

             txn.execute(sql, args)
             count = cast(Tuple[int], txn.fetchone())[0]

View file

@@ -232,9 +232,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
                 AND m.room_id = c.room_id
                 AND m.user_id = c.state_key
                 WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ? AND %s
-            """ % (
-                clause,
-            )
+            """ % (clause,)
             txn.execute(sql, (room_id, Membership.JOIN, *ids))

             return {r[0]: ProfileInfo(display_name=r[1], avatar_url=r[2]) for r in txn}
@@ -531,9 +529,7 @@
                 WHERE
                     user_id = ?
                     AND %s
-            """ % (
-                clause,
-            )
+            """ % (clause,)

             txn.execute(sql, (user_id, *args))

             results = [
@@ -813,7 +809,7 @@
             """
             txn.execute(sql, (user_id, *args))
-            return {u: True for u, in txn}
+            return {u: True for (u,) in txn}

         to_return = {}
         for batch_user_ids in batch_iter(other_user_ids, 1000):
@@ -1031,7 +1027,7 @@
                 AND room_id = ?
             """
             txn.execute(sql, (room_id,))
-            return {d for d, in txn}
+            return {d for (d,) in txn}

         return await self.db_pool.runInteraction(
             "get_current_hosts_in_room", get_current_hosts_in_room_txn
@@ -1099,7 +1095,7 @@
             """
             txn.execute(sql, (room_id,))
             # `server_domain` will be `NULL` for malformed MXIDs with no colons.
-            return tuple(d for d, in txn if d is not None)
+            return tuple(d for (d,) in txn if d is not None)

         return await self.db_pool.runInteraction(
             "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
@@ -1316,9 +1312,7 @@
             room_id = ? AND membership = ?
             AND NOT (%s)
             LIMIT 1
-        """ % (
-            clause,
-        )
+        """ % (clause,)

         def _is_local_host_in_room_ignoring_users_txn(
             txn: LoggingTransaction,
@@ -1464,10 +1458,12 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
         self, progress: JsonDict, batch_size: int
     ) -> int:
         target_min_stream_id = progress.get(
-            "target_min_stream_id_inclusive", self._min_stream_order_on_start  # type: ignore[attr-defined]
+            "target_min_stream_id_inclusive",
+            self._min_stream_order_on_start,  # type: ignore[attr-defined]
         )
         max_stream_id = progress.get(
-            "max_stream_id_exclusive", self._stream_order_on_start + 1  # type: ignore[attr-defined]
+            "max_stream_id_exclusive",
+            self._stream_order_on_start + 1,  # type: ignore[attr-defined]
         )

         def add_membership_profile_txn(txn: LoggingTransaction) -> int:

View file

@@ -177,9 +177,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
                 AND (%s)
                 ORDER BY stream_ordering DESC
                 LIMIT ?
-            """ % (
-                " OR ".join("type = '%s'" % (t,) for t in TYPES),
-            )
+            """ % (" OR ".join("type = '%s'" % (t,) for t in TYPES),)

             txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))

View file

@@ -535,7 +535,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             desc="check_if_events_in_current_state",
         )

-        return frozenset(event_id for event_id, in rows)
+        return frozenset(event_id for (event_id,) in rows)

     # FIXME: how should this be cached?
     @cancellable

View file

@@ -161,7 +161,7 @@ class StatsStore(StateDeltasStore):
                 LIMIT ?
             """
             txn.execute(sql, (last_user_id, batch_size))
-            return [r for r, in txn]
+            return [r for (r,) in txn]

         users_to_work_on = await self.db_pool.runInteraction(
             "_populate_stats_process_users", _get_next_batch
@@ -207,7 +207,7 @@
                 LIMIT ?
             """
             txn.execute(sql, (last_room_id, batch_size))
-            return [r for r, in txn]
+            return [r for (r,) in txn]

         rooms_to_work_on = await self.db_pool.runInteraction(
             "populate_stats_rooms_get_batch", _get_next_batch
@@ -751,9 +751,7 @@
                 LEFT JOIN profiles AS p ON lmr.user_id = p.full_user_id
                 {}
                 GROUP BY lmr.user_id, displayname
-            """.format(
-                where_clause
-            )
+            """.format(where_clause)

             # SQLite does not support SELECT COUNT(*) OVER()
             sql = """

View file

@@ -21,7 +21,7 @@
 #
 #
-""" This module is responsible for getting events from the DB for pagination
+"""This module is responsible for getting events from the DB for pagination
 and event streaming.

 The order it returns events in depend on whether we are streaming forwards or
@@ -1122,9 +1122,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 AND e.stream_ordering > ? AND e.stream_ordering <= ?
                 %s
                 ORDER BY e.stream_ordering ASC
-            """ % (
-                ignore_room_clause,
-            )
+            """ % (ignore_room_clause,)

             txn.execute(sql, args)
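The first of the two hunks above is a docstring tweak: ruff format strips the stray space after the opening quotes of a module docstring. A tiny runnable sketch of the before and after strings:

# Before: a space followed the opening quotes.
doc_before = """ This module is responsible for getting events for pagination."""

# After: the leading space is removed.
doc_after = """This module is responsible for getting events for pagination."""

assert doc_after == doc_before.lstrip()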

View file

@@ -224,9 +224,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 SELECT room_id, events FROM %s
                 ORDER BY events DESC
                 LIMIT 250
-            """ % (
-                TEMP_TABLE + "_rooms",
-            )
+            """ % (TEMP_TABLE + "_rooms",)
             txn.execute(sql)
             rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())

View file

@@ -767,7 +767,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             remaining_state_groups = {
                 state_group
-                for state_group, in rows
+                for (state_group,) in rows
                 if state_group not in state_groups_to_delete
             }

View file

@@ -607,7 +607,7 @@ def _apply_module_schema_files(
         "SELECT file FROM applied_module_schemas WHERE module_name = ?",
         (modname,),
     )
-    applied_deltas = {d for d, in cur}
+    applied_deltas = {d for (d,) in cur}
     for name, stream in names_and_streams:
         if name in applied_deltas:
             continue
@@ -710,7 +710,7 @@ def _get_or_create_schema_state(
         "SELECT file FROM applied_schema_deltas WHERE version >= ?",
         (current_version,),
     )
-    applied_deltas = tuple(d for d, in txn)
+    applied_deltas = tuple(d for (d,) in txn)

     return _SchemaState(
         current_version=current_version,

View file

@@ -41,8 +41,6 @@ def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) ->
             (user_id, filter_id);
         DROP TABLE user_filters;
         ALTER TABLE user_filters_migration RENAME TO user_filters;
-    """ % (
-        select_clause,
-    )
+    """ % (select_clause,)

     execute_statements_from_stream(cur, StringIO(sql))

View file

@@ -23,6 +23,7 @@
 This migration handles the process of changing the type of `room_depth.min_depth` to
 a BIGINT.
 """
+
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine

View file

@@ -25,6 +25,7 @@ This migration adds triggers to the partial_state_events tables to enforce uniqu
 Triggers cannot be expressed in .sql files, so we have to use a separate file.
 """
+
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine

View file

@@ -26,6 +26,7 @@ for its completion can be removed.
 Note the background job must still remain defined in the database class.
 """
+
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.engines import BaseDatabaseEngine
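The last three hunks are the same one-line change in three schema-delta modules: a blank line is inserted between the module docstring and the first import. A minimal sketch (hypothetical module, standing in for the Synapse delta scripts):

# Before, the first import came directly after the closing quotes of the docstring.
# After ruff format there is exactly one blank line in between:
"""Hypothetical schema delta module illustrating the blank-line-after-docstring rule."""

import logging

logging.getLogger(__name__).debug("schema delta module imported")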

Some files were not shown because too many files have changed in this diff.