diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index 527990cf9..bfb2037f8 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -308,8 +308,12 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req // Finally, generate a notification that we updated the signatures. for userID := range req.Signatures { + masterKey := queryRes.MasterKeys[userID] + selfSigningKey := queryRes.SelfSigningKeys[userID] update := eduserverAPI.CrossSigningKeyUpdate{ - UserID: userID, + UserID: userID, + MasterKey: &masterKey, + SelfSigningKey: &selfSigningKey, } if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil { res.Error = &api.KeyError{ diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 259249217..371dda6d0 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -243,49 +243,45 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } domain := string(serverName) // query local devices - if serverName == a.ThisServer { - deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to query local device keys: %s", err), - } - return + deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to query local device keys: %s", err), } - - // pull out display names after we have the keys so we handle wildcards correctly - var dids []string - for _, dk := range deviceKeys { - dids = append(dids, dk.DeviceID) - } - var queryRes userapi.QueryDeviceInfosResponse - err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{ - DeviceIDs: dids, - }, &queryRes) - if err != nil { - util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing") - } - - if res.DeviceKeys[userID] == nil { - res.DeviceKeys[userID] = make(map[string]json.RawMessage) - } - for _, dk := range deviceKeys { - if len(dk.KeyJSON) == 0 { - continue // don't include blank keys - } - // inject display name if known (either locally or remotely) - displayName := dk.DisplayName - if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" { - displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName - } - dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { - DisplayName string `json:"device_display_name,omitempty"` - }{displayName}) - res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON - } - } else { - domainToDeviceKeys[domain] = make(map[string][]string) - domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...) + return } + + // pull out display names after we have the keys so we handle wildcards correctly + var dids []string + for _, dk := range deviceKeys { + dids = append(dids, dk.DeviceID) + } + var queryRes userapi.QueryDeviceInfosResponse + err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{ + DeviceIDs: dids, + }, &queryRes) + if err != nil { + util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing") + } + + if res.DeviceKeys[userID] == nil { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + } + for _, dk := range deviceKeys { + if len(dk.KeyJSON) == 0 { + continue // don't include blank keys + } + // inject display name if known (either locally or remotely) + displayName := dk.DisplayName + if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" { + displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName + } + dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { + DisplayName string `json:"device_display_name,omitempty"` + }{displayName}) + res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON + } + // work out if our cross-signing request for this user was // satisfied, if not add them to the list of things to fetch if _, ok := res.MasterKeys[userID]; !ok { @@ -326,8 +322,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques if err = json.Unmarshal(key, &deviceKey); err != nil { continue } + if deviceKey.Signatures == nil { + deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } for sourceUserID, forSourceUser := range sigMap { for sourceKeyID, sourceSig := range forSourceUser { + if _, ok := deviceKey.Signatures[sourceUserID]; !ok { + deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig } }