Fix retrieving cross-signing signatures in /user/devices/{userId} (#2368)

* Fix retrieving cross-signing signatures in `/user/devices/{userId}`

We need to know the target device IDs in order to get the signatures and we weren't populating those.

* Fix up signature retrieval

* Fix SQLite

* Always include the target's own signatures as well as the requesting user
This commit is contained in:
Neil Alexander 2022-04-22 14:58:24 +01:00 committed by GitHub
parent c07f347f00
commit 6d78c4d67d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 52 additions and 26 deletions

View file

@ -43,6 +43,9 @@ func GetUserDevices(
},
}
sigRes := &keyapi.QuerySignaturesResponse{}
for _, dev := range res.Devices {
sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID))
}
keyAPI.QuerySignatures(req.Context(), sigReq, sigRes)
response := gomatrixserverlib.RespUserDevices{

View file

@ -455,10 +455,10 @@ func (a *KeyInternalAPI) processOtherSignatures(
func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse,
) {
for userID := range req.UserToDevices {
keys, err := a.DB.CrossSigningKeysForUser(ctx, userID)
for targetUserID := range req.UserToDevices {
keys, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
if err != nil {
logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", userID)
logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID)
continue
}
@ -469,9 +469,9 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
break
}
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, userID, keyID)
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID)
if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", userID, keyID)
logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID)
continue
}
@ -491,7 +491,7 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
case req.UserID != "" && originUserID == req.UserID:
// Include signatures that we created
appendSignature(originUserID, originKeyID, signature)
case originUserID == userID:
case originUserID == targetUserID:
// Include signatures that were created by the person whose key
// we are processing
appendSignature(originUserID, originKeyID, signature)
@ -501,13 +501,13 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
switch keyType {
case gomatrixserverlib.CrossSigningKeyPurposeMaster:
res.MasterKeys[userID] = key
res.MasterKeys[targetUserID] = key
case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning:
res.SelfSigningKeys[userID] = key
res.SelfSigningKeys[targetUserID] = key
case gomatrixserverlib.CrossSigningKeyPurposeUserSigning:
res.UserSigningKeys[userID] = key
res.UserSigningKeys[targetUserID] = key
}
}
}
@ -546,7 +546,8 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign
}
for _, targetKeyID := range forTargetUser {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetKeyID)
// Get own signatures only.
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID)
if err != nil && err != sql.ErrNoRows {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err),

View file

@ -313,9 +313,31 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
// Finally, append signatures that we know about
// TODO: This is horrible because we need to round-trip the signature from
// JSON, add the signatures and marshal it again, for some reason?
for userID, forUserID := range res.DeviceKeys {
for keyID, key := range forUserID {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, userID, gomatrixserverlib.KeyID(keyID))
for targetUserID, masterKey := range res.MasterKeys {
for targetKeyID := range masterKey.Keys {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
if err != nil {
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue
}
if len(sigMap) == 0 {
continue
}
for sourceUserID, forSourceUser := range sigMap {
for sourceKeyID, sourceSig := range forSourceUser {
if _, ok := masterKey.Signatures[sourceUserID]; !ok {
masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
masterKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
}
}
}
}
for targetUserID, forUserID := range res.DeviceKeys {
for targetKeyID, key := range forUserID {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
if err != nil {
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue
@ -339,7 +361,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
}
if js, err := json.Marshal(deviceKey); err == nil {
res.DeviceKeys[userID][keyID] = js
res.DeviceKeys[targetUserID][targetKeyID] = js
}
}
}

View file

@ -81,7 +81,7 @@ type Database interface {
CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error)
CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error)
CrossSigningSigsForTarget(ctx context.Context, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error)
CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error)
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error

View file

@ -39,7 +39,7 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
const selectCrossSigningSigsForTargetSQL = "" +
"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
" WHERE target_user_id = $1 AND target_key_id = $2"
" WHERE (origin_user_id = $1 OR origin_user_id = target_user_id) AND target_user_id = $2 AND target_key_id = $3"
const upsertCrossSigningSigsForTargetSQL = "" +
"INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
@ -72,9 +72,9 @@ func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, erro
}
func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) (r types.CrossSigningSigMap, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID)
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID)
if err != nil {
return nil, err
}

View file

@ -190,7 +190,7 @@ func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (
keyID: key,
},
}
sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, keyID)
sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID)
if err != nil {
continue
}
@ -219,8 +219,8 @@ func (d *Database) CrossSigningKeysDataForUser(ctx context.Context, userID strin
}
// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any.
func (d *Database) CrossSigningSigsForTarget(ctx context.Context, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, targetUserID, targetKeyID)
func (d *Database) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID)
}
// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.

View file

@ -39,7 +39,7 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
const selectCrossSigningSigsForTargetSQL = "" +
"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
" WHERE target_user_id = $1 AND target_key_id = $2"
" WHERE (origin_user_id = $1 OR origin_user_id = target_user_id) AND target_user_id = $2 AND target_key_id = $3"
const upsertCrossSigningSigsForTargetSQL = "" +
"INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
@ -71,13 +71,13 @@ func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error)
}
func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) (r types.CrossSigningSigMap, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID)
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed")
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForOriginTargetStmt: rows.close() failed")
r = types.CrossSigningSigMap{}
for rows.Next() {
var userID string

View file

@ -64,7 +64,7 @@ type CrossSigningKeys interface {
}
type CrossSigningSigs interface {
SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error)
SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error)
UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error
}