diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 774e71dd3..873a051cd 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -255,13 +255,32 @@ func (r *Inputer) processRoomEvent( hadEvents: map[string]bool{}, haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, } - if err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { + if stateSnapshot, err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { + // Something went wrong with retrieving the missing state, so we can't + // really do anything with the event other than reject it at this point. isRejected = true rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) + } else if stateSnapshot != nil { + // We retrieved some state and we ended up having to call /state_ids for + // the new event in question (probably because closing the gap by using + // /get_missing_events didn't do what we hoped) so we'll instead overwrite + // the state snapshot with the newly resolved state. + missingPrev = false + input.HasState = true + input.StateEventIDs = make([]string, 0, len(stateSnapshot.StateEvents)) + for _, e := range stateSnapshot.StateEvents { + input.StateEventIDs = append(input.StateEventIDs, e.EventID()) + } } else { + // We retrieved some state and it would appear that rolling forward the + // state did everything we needed it to do, so we can just resolve the + // state for the event in the normal way. missingPrev = false } } else { + // We're missing prev events or state for the event, but for some reason + // we don't know any servers to ask. In this case we can't do anything but + // reject the event and hope that it gets unrejected later. isRejected = true rejectionErr = fmt.Errorf("missing prev events and no other servers to ask") } @@ -299,7 +318,7 @@ func (r *Inputer) processRoomEvent( return rollbackTransaction, fmt.Errorf("updater.RoomInfo missing for room %s", event.RoomID()) } - if !missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0 { + if input.HasState || (!missingPrev && stateAtEvent.BeforeStateSnapshotNID == 0) { // We haven't calculated a state for this event yet. // Lets calculate one. err = r.calculateAndSetState(ctx, updater, input, roomInfo, &stateAtEvent, event, isRejected) diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 497c049dc..19771d4bd 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -40,9 +40,10 @@ type missingStateReq struct { // processEventWithMissingState is the entrypoint for a missingStateReq // request, as called from processRoomEvent. +// nolint:gocyclo func (t *missingStateReq) processEventWithMissingState( ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, -) error { +) (*parsedRespState, error) { // We are missing the previous events for this events. // This means that there is a gap in our view of the history of the // room. There two ways that we can handle such a gap: @@ -68,15 +69,15 @@ func (t *missingStateReq) processEventWithMissingState( // - fill in the gap completely then process event `e` returning no backwards extremity // - fail to fill in the gap and tell us to terminate the transaction err=not nil // - fail to fill in the gap and tell us to fetch state at the new backwards extremity, and to not terminate the transaction - newEvents, isGapFilled, err := t.getMissingEvents(ctx, e, roomVersion) + newEvents, isGapFilled, prevStatesKnown, err := t.getMissingEvents(ctx, e, roomVersion) if err != nil { - return fmt.Errorf("t.getMissingEvents: %w", err) + return nil, fmt.Errorf("t.getMissingEvents: %w", err) } if len(newEvents) == 0 { - return fmt.Errorf("expected to find missing events but didn't") + return nil, fmt.Errorf("expected to find missing events but didn't") } if isGapFilled { - logger.Infof("gap filled by /get_missing_events, injecting %d new events", len(newEvents)) + logger.Infof("Gap filled by /get_missing_events, injecting %d new events", len(newEvents)) // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // in the gap in the DAG for _, newEvent := range newEvents { @@ -88,82 +89,31 @@ func (t *missingStateReq) processEventWithMissingState( }) if err != nil { if _, ok := err.(types.RejectedError); !ok { - return fmt.Errorf("t.inputer.processRoomEvent (filling gap): %w", err) + return nil, fmt.Errorf("t.inputer.processRoomEvent (filling gap): %w", err) } } } - return nil } + // If we filled the gap *and* we know the state before the prev events + // then there's nothing else to do, we have everything we need to deal + // with the new event. + if isGapFilled && prevStatesKnown { + logger.Infof("Gap filled and state found for all prev events") + return nil, nil + } + + // Otherwise, if we've reached this point, it's possible that we've + // either not closed the gap, or we did but we still don't seem to + // know the events before the new event. Start by looking up the + // state at the event at the back of the gap and we'll try to roll + // forward the state first. backwardsExtremity := newEvents[0] newEvents = newEvents[1:] - type respState struct { - // A snapshot is considered trustworthy if it came from our own roomserver. - // That's because the state will have been through state resolution once - // already in QueryStateAfterEvent. - trustworthy bool - *parsedRespState - } - - // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. - // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query - // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. - var states []*respState - for _, prevEventID := range backwardsExtremity.PrevEventIDs() { - // Look up what the state is after the backward extremity. This will either - // come from the roomserver, if we know all the required events, or it will - // come from a remote server via /state_ids if not. - prevState, trustworthy, lerr := t.lookupStateAfterEvent(ctx, roomVersion, backwardsExtremity.RoomID(), prevEventID) - if lerr != nil { - logger.WithError(lerr).Errorf("Failed to lookup state after prev_event: %s", prevEventID) - return lerr - } - // Append the state onto the collected state. We'll run this through the - // state resolution next. - states = append(states, &respState{trustworthy, prevState}) - } - - // Now that we have collected all of the state from the prev_events, we'll - // run the state through the appropriate state resolution algorithm for the - // room if needed. This does a couple of things: - // 1. Ensures that the state is deduplicated fully for each state-key tuple - // 2. Ensures that we pick the latest events from both sets, in the case that - // one of the prev_events is quite a bit older than the others - resolvedState := &parsedRespState{} - switch len(states) { - case 0: - extremityIsCreate := backwardsExtremity.Type() == gomatrixserverlib.MRoomCreate && backwardsExtremity.StateKeyEquals("") - if !extremityIsCreate { - // There are no previous states and this isn't the beginning of the - // room - this is an error condition! - logger.Errorf("Failed to lookup any state after prev_events") - return fmt.Errorf("expected %d states but got %d", len(backwardsExtremity.PrevEventIDs()), len(states)) - } - case 1: - // There's only one previous state - if it's trustworthy (came from a - // local state snapshot which will already have been through state res), - // use it as-is. There's no point in resolving it again. - if states[0].trustworthy { - resolvedState = states[0].parsedRespState - break - } - // Otherwise, if it isn't trustworthy (came from federation), run it through - // state resolution anyway for safety, in case there are duplicates. - fallthrough - default: - respStates := make([]*parsedRespState, len(states)) - for i := range states { - respStates[i] = states[i].parsedRespState - } - // There's more than one previous state - run them all through state res - t.roomsMu.Lock(e.RoomID()) - resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, backwardsExtremity) - t.roomsMu.Unlock(e.RoomID()) - if err != nil { - logger.WithError(err).Errorf("Failed to resolve state conflicts for event %s", backwardsExtremity.EventID()) - return err - } + resolvedState, err := t.lookupResolvedStateBeforeEvent(ctx, backwardsExtremity, roomVersion) + if err != nil { + return nil, fmt.Errorf("t.lookupState (backwards extremity): %w", err) } hadEvents := map[string]bool{} @@ -173,30 +123,37 @@ func (t *missingStateReq) processEventWithMissingState( } t.hadEventsMutex.Unlock() - // Send outliers first so we can send the new backwards extremity without causing errors - outliers, err := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion) - if err != nil { - return err - } - var outlierRoomEvents []api.InputRoomEvent - for _, outlier := range outliers { - if hadEvents[outlier.EventID()] { - continue + sendOutliers := func(resolvedState *parsedRespState) error { + outliers, oerr := gomatrixserverlib.OrderAuthAndStateEvents(resolvedState.AuthEvents, resolvedState.StateEvents, roomVersion) + if oerr != nil { + return fmt.Errorf("gomatrixserverlib.OrderAuthAndStateEvents: %w", oerr) } - outlierRoomEvents = append(outlierRoomEvents, api.InputRoomEvent{ - Kind: api.KindOutlier, - Event: outlier.Headered(roomVersion), - Origin: t.origin, - }) - } - // TODO: we could do this concurrently? - for _, ire := range outlierRoomEvents { - _, err = t.inputer.processRoomEvent(ctx, t.db, &ire) - if err != nil { - if _, ok := err.(types.RejectedError); !ok { - return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) + var outlierRoomEvents []api.InputRoomEvent + for _, outlier := range outliers { + if hadEvents[outlier.EventID()] { + continue + } + outlierRoomEvents = append(outlierRoomEvents, api.InputRoomEvent{ + Kind: api.KindOutlier, + Event: outlier.Headered(roomVersion), + Origin: t.origin, + }) + } + for _, ire := range outlierRoomEvents { + _, err = t.inputer.processRoomEvent(ctx, t.db, &ire) + if err != nil { + if _, ok := err.(types.RejectedError); !ok { + return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) + } } } + return nil + } + + // Send outliers first so we can send the state along with the new backwards + // extremity without any missing auth events. + if err = sendOutliers(resolvedState); err != nil { + return nil, fmt.Errorf("sendOutliers: %w", err) } // Now send the backward extremity into the roomserver with the @@ -217,7 +174,7 @@ func (t *missingStateReq) processEventWithMissingState( }) if err != nil { if _, ok := err.(types.RejectedError); !ok { - return fmt.Errorf("t.inputer.processRoomEvent (backward extremity): %w", err) + return nil, fmt.Errorf("t.inputer.processRoomEvent (backward extremity): %w", err) } } @@ -234,12 +191,109 @@ func (t *missingStateReq) processEventWithMissingState( }) if err != nil { if _, ok := err.(types.RejectedError); !ok { - return fmt.Errorf("t.inputer.processRoomEvent (fast forward): %w", err) + return nil, fmt.Errorf("t.inputer.processRoomEvent (fast forward): %w", err) } } } - return nil + // Finally, check again if we know everything we need to know in order to + // make forward progress. If the prev state is known then we consider the + // rolled forward state to be sufficient — we now know all of the state + // before the prev events. If we don't then we need to look up the state + // before the new event as well, otherwise we will never make any progress. + if t.isPrevStateKnown(ctx, e) { + return nil, nil + } + + // If we still haven't got the state for the prev events then we'll go and + // ask the federation for it if needed. + resolvedState, err = t.lookupResolvedStateBeforeEvent(ctx, e, roomVersion) + if err != nil { + return nil, fmt.Errorf("t.lookupState (new event): %w", err) + } + + // Send the outliers for the retrieved state. + if err = sendOutliers(resolvedState); err != nil { + return nil, fmt.Errorf("sendOutliers: %w", err) + } + + // Then return the resolved state, for which the caller can replace the + // HasState with the event IDs to create a new state snapshot when we + // process the new event. + return resolvedState, nil +} + +func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (*parsedRespState, error) { + type respState struct { + // A snapshot is considered trustworthy if it came from our own roomserver. + // That's because the state will have been through state resolution once + // already in QueryStateAfterEvent. + trustworthy bool + *parsedRespState + } + + // at this point we know we're going to have a gap: we need to work out the room state at the new backwards extremity. + // Therefore, we cannot just query /state_ids with this event to get the state before. Instead, we need to query + // the state AFTER all the prev_events for this event, then apply state resolution to that to get the state before the event. + var states []*respState + for _, prevEventID := range e.PrevEventIDs() { + // Look up what the state is after the backward extremity. This will either + // come from the roomserver, if we know all the required events, or it will + // come from a remote server via /state_ids if not. + prevState, trustworthy, err := t.lookupStateAfterEvent(ctx, roomVersion, e.RoomID(), prevEventID) + if err != nil { + return nil, fmt.Errorf("t.lookupStateAfterEvent: %w", err) + } + // Append the state onto the collected state. We'll run this through the + // state resolution next. + states = append(states, &respState{trustworthy, prevState}) + } + + // Now that we have collected all of the state from the prev_events, we'll + // run the state through the appropriate state resolution algorithm for the + // room if needed. This does a couple of things: + // 1. Ensures that the state is deduplicated fully for each state-key tuple + // 2. Ensures that we pick the latest events from both sets, in the case that + // one of the prev_events is quite a bit older than the others + resolvedState := &parsedRespState{} + switch len(states) { + case 0: + extremityIsCreate := e.Type() == gomatrixserverlib.MRoomCreate && e.StateKeyEquals("") + if !extremityIsCreate { + // There are no previous states and this isn't the beginning of the + // room - this is an error condition! + return nil, fmt.Errorf("expected %d states but got %d", len(e.PrevEventIDs()), len(states)) + } + case 1: + // There's only one previous state - if it's trustworthy (came from a + // local state snapshot which will already have been through state res), + // use it as-is. There's no point in resolving it again. Only trust a + // trustworthy state snapshot if it actually contains some state for all + // non-create events, otherwise we need to resolve what came from federation. + isCreate := e.Type() == gomatrixserverlib.MRoomCreate && e.StateKeyEquals("") + if states[0].trustworthy && (isCreate || len(states[0].StateEvents) > 0) { + resolvedState = states[0].parsedRespState + break + } + // Otherwise, if it isn't trustworthy (came from federation), run it through + // state resolution anyway for safety, in case there are duplicates. + fallthrough + default: + respStates := make([]*parsedRespState, len(states)) + for i := range states { + respStates[i] = states[i].parsedRespState + } + // There's more than one previous state - run them all through state res + var err error + t.roomsMu.Lock(e.RoomID()) + resolvedState, err = t.resolveStatesAndCheck(ctx, roomVersion, respStates, e) + t.roomsMu.Unlock(e.RoomID()) + if err != nil { + return nil, fmt.Errorf("t.resolveStatesAndCheck: %w", err) + } + } + + return resolvedState, nil } // lookupStateAfterEvent returns the room state after `eventID`, which is the state before eventID with the state of `eventID` (if it's a state event) @@ -408,7 +462,7 @@ retryAllowedState: // get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject, // without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events -func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled bool, err error) { +func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) { logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) latest := t.db.LatestEvents() @@ -435,7 +489,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve if errors.Is(err, context.DeadlineExceeded) { select { case <-ctx.Done(): // the parent request context timed out - return nil, false, context.DeadlineExceeded + return nil, false, false, context.DeadlineExceeded default: // this request exceed its own timeout continue } @@ -448,7 +502,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve "%s pushed us an event but %d server(s) couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", t.origin, len(t.servers), ) - return nil, false, missingPrevEventsError{ + return nil, false, false, missingPrevEventsError{ eventID: e.EventID(), err: err, } @@ -457,17 +511,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve // Make sure events from the missingResp are using the cache - missing events // will be added and duplicates will be removed. logger.Debugf("get_missing_events returned %d events", len(missingResp.Events)) - missingEvents := make([]*gomatrixserverlib.Event, len(missingResp.Events)) - for i, evJSON := range missingResp.Events { - ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(evJSON, roomVersion) - if err != nil { - logger.WithError(err).WithField("event", string(evJSON)).Warn("NewEventFromUntrustedJSON: failed") - return nil, false, missingPrevEventsError{ - eventID: e.EventID(), - err: err, - } - } - missingEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + missingEvents := make([]*gomatrixserverlib.Event, 0, len(missingResp.Events)) + for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { + missingEvents = append(missingEvents, t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap()) } // topologically sort and sanity check that we are making forward progress @@ -489,27 +535,51 @@ Event: "%s pushed us an event but couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", t.origin, ) - return nil, false, missingPrevEventsError{ + return nil, false, false, missingPrevEventsError{ eventID: e.EventID(), err: err, } } if len(newEvents) == 0 { - return nil, false, nil // TODO: error instead? + return nil, false, false, nil // TODO: error instead? } - // now check if we can fill the gap. Look to see if we have state snapshot IDs for the earliest event earliestNewEvent := newEvents[0] - if state, err := t.db.StateAtEventIDs(ctx, []string{earliestNewEvent.EventID()}); err != nil || len(state) == 0 { - if earliestNewEvent.Type() == gomatrixserverlib.MRoomCreate && earliestNewEvent.StateKeyEquals("") { - // we got to the beginning of the room so there will be no state! It's all good we can process this - return newEvents, true, nil - } - // we don't have the state at this earliest event from /g_m_e so we won't have state for later events either - return newEvents, false, nil + + // If we retrieved back to the beginning of the room then there's nothing else + // to do - we closed the gap. + if len(earliestNewEvent.PrevEventIDs()) == 0 && earliestNewEvent.Type() == gomatrixserverlib.MRoomCreate && earliestNewEvent.StateKeyEquals("") { + return newEvents, true, t.isPrevStateKnown(ctx, e), nil } - // StateAtEventIDs returned some kind of state for the earliest event so we can fill in the gap! - return newEvents, true, nil + + // If our backward extremity was not a known event to us then we obviously didn't + // close the gap. + if state, err := t.db.StateAtEventIDs(ctx, []string{earliestNewEvent.EventID()}); err != nil || len(state) == 0 && state[0].BeforeStateSnapshotNID == 0 { + return newEvents, false, false, nil + } + + // At this point we are satisfied that we know the state both at the earliest + // retrieved event and at the prev events of the new event. + return newEvents, true, t.isPrevStateKnown(ctx, e), nil +} + +func (t *missingStateReq) isPrevStateKnown(ctx context.Context, e *gomatrixserverlib.Event) bool { + expected := len(e.PrevEventIDs()) + state, err := t.db.StateAtEventIDs(ctx, e.PrevEventIDs()) + if err != nil || len(state) != expected { + // We didn't get as many state snapshots as we expected, or there was an error, + // so we haven't completely solved the problem for the new event. + return false + } + // Check to see if we have a populated state snapshot for all of the prev events. + for _, stateAtEvent := range state { + if stateAtEvent.BeforeStateSnapshotNID == 0 { + // One of the prev events still has unknown state, so we haven't really + // solved the problem. + return false + } + } + return true } func (t *missingStateReq) lookupMissingStateViaState( diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 845533032..05cd686f4 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -150,7 +150,7 @@ func (r *Queryer) QueryMissingAuthPrevEvents( for _, prevEventID := range request.PrevEventIDs { state, err := r.DB.StateAtEventIDs(ctx, []string{prevEventID}) - if err != nil || len(state) == 0 { + if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) { response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID) } } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index bb9f5dc62..fc75a2606 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -187,6 +187,12 @@ func (u *RoomUpdater) EventIDs( return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs) } +func (u *RoomUpdater) EventNIDs( + ctx context.Context, eventIDs []string, +) (map[string]types.EventNID, error) { + return u.d.eventNIDs(ctx, u.txn, eventIDs) +} + func (u *RoomUpdater) StateAtEventIDs( ctx context.Context, eventIDs []string, ) ([]types.StateAtEvent, error) { diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 127cd1f52..8319de265 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -603,6 +603,8 @@ func (d *Database) storeEvent( if err == sql.ErrNoRows { // We've already inserted the event so select the numeric event ID eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID()) + } else if err != nil { + return fmt.Errorf("d.EventsTable.InsertEvent: %w", err) } if err != nil { return fmt.Errorf("d.EventsTable.SelectEvent: %w", err) diff --git a/roomserver/types/types.go b/roomserver/types/types.go index 5e1eebe98..5d52ccfcd 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -83,6 +83,10 @@ type StateKeyTuple struct { EventStateKeyNID EventStateKeyNID } +func (a StateKeyTuple) IsCreate() bool { + return a.EventTypeNID == MRoomCreateNID && a.EventStateKeyNID == EmptyStateKeyNID +} + // LessThan returns true if this state key is less than the other state key. // The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries. func (a StateKeyTuple) LessThan(b StateKeyTuple) bool {