dendrite/roomserver/storage/shared/storage.go
Kegsay 19aa44ecae
Convert transactions/rooms table to share more code (#1063)
* Convert rooms table

* Convert transactions table

* Convert rooms table and factor out lots of functions

* I think you'll be needing this..
2020-05-26 18:23:39 +01:00

381 lines
12 KiB
Go

package shared
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
// Database bundles a SQL connection with the roomserver table accessors that
// the methods in this file delegate to.
type Database struct {
// DB is the connection handed to internal.WithTransaction when a method
// needs to run several table operations in one transaction.
DB *sql.DB
// EventsTable is used for event NID/ID lookups, state lookups and inserts.
EventsTable tables.Events
// EventJSONTable is used to store and bulk-load the JSON of events.
EventJSONTable tables.EventJSON
// EventTypesTable is used to look up and assign event type NIDs.
EventTypesTable tables.EventTypes
// EventStateKeysTable is used to look up and assign state key NIDs.
EventStateKeysTable tables.EventStateKeys
// RoomsTable is used to look up and assign room NIDs, and to read room
// versions and latest event NIDs.
RoomsTable tables.Rooms
// TransactionsTable is used to record and look up the event ID associated
// with a (transaction ID, session ID, user ID) triple.
TransactionsTable tables.Transactions
}
// EventTypeNIDs implements state.RoomStateDatabase
//
// It returns the map from event type string to types.EventTypeNID produced by
// the event types table's bulk lookup for the given event types.
func (d *Database) EventTypeNIDs(
	ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
	nidsByType, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes)
	return nidsByType, err
}
// EventStateKeys implements query.RoomserverQueryAPIDatabase
//
// It returns the map from types.EventStateKeyNID to state key string produced
// by the state keys table's bulk lookup for the given NIDs.
func (d *Database) EventStateKeys(
	ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
	keysByNID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs)
	return keysByNID, err
}
// EventStateKeyNIDs implements state.RoomStateDatabase
//
// It returns the map from state key string to types.EventStateKeyNID produced
// by the state keys table's bulk lookup for the given state keys.
func (d *Database) EventStateKeyNIDs(
	ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
	nidsByKey, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys)
	return nidsByKey, err
}
// StateEntriesForEventIDs implements input.EventDatabase
//
// It returns the state entries that the events table reports for the given
// event IDs.
func (d *Database) StateEntriesForEventIDs(
	ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
	entries, err := d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs)
	return entries, err
}
// EventNIDs implements query.RoomserverQueryAPIDatabase
//
// It returns the map from event ID to types.EventNID produced by the events
// table's bulk lookup for the given event IDs.
func (d *Database) EventNIDs(
	ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
	nidsByID, err := d.EventsTable.BulkSelectEventNID(ctx, eventIDs)
	return nidsByID, err
}
// SetState implements input.EventDatabase
//
// It asks the events table to update the state of the event identified by
// eventNID to the snapshot identified by stateNID.
func (d *Database) SetState(
	ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
	events := d.EventsTable
	return events.UpdateEventState(ctx, eventNID, stateNID)
}
// StateAtEventIDs implements input.EventDatabase
//
// It returns the types.StateAtEvent values that the events table reports for
// the given event IDs.
func (d *Database) StateAtEventIDs(
	ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
	stateAtEvents, err := d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs)
	return stateAtEvents, err
}
// SnapshotNIDFromEventID implements state.RoomStateDatabase
//
// It selects the event with the given ID outside of any transaction and
// returns the state snapshot NID from that row, discarding the event NID.
func (d *Database) SnapshotNIDFromEventID(
	ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) {
	events := d.EventsTable
	_, snapshotNID, err := events.SelectEvent(ctx, nil, eventID)
	return snapshotNID, err
}
// EventIDs implements input.RoomEventDatabase
//
// It returns the map from types.EventNID to event ID produced by the events
// table's bulk lookup for the given event NIDs.
func (d *Database) EventIDs(
	ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) {
	idsByNID, err := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
	return idsByNID, err
}
// EventsFromIDs implements query.RoomserverQueryAPIEventDB
//
// It resolves the given event IDs to event NIDs via EventNIDs and then loads
// the corresponding events via Events. The NIDs are collected by ranging over
// a map, so the order in which they are passed to Events is unspecified.
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
	nidsByID, err := d.EventNIDs(ctx, eventIDs)
	if err != nil {
		return nil, err
	}

	// Deliberately left nil when no IDs resolved, so Events sees a nil slice.
	var eventNIDs []types.EventNID
	for eventID := range nidsByID {
		eventNIDs = append(eventNIDs, nidsByID[eventID])
	}

	return d.Events(ctx, eventNIDs)
}
// RoomNID implements query.RoomserverQueryAPIDB
//
// It looks up the numeric ID for roomID outside of any transaction. A missing
// row (sql.ErrNoRows) is not treated as an error: in that case 0 is returned
// with a nil error. Any other error is returned as-is along with whatever
// value the rooms table returned.
func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) {
	nid, err := d.RoomsTable.SelectRoomNID(ctx, nil, roomID)
	switch err {
	case sql.ErrNoRows:
		return 0, nil
	default:
		return nid, err
	}
}
// RoomNIDExcludingStubs implements query.RoomserverQueryAPIDB
//
// It resolves roomID via RoomNID and then reads the room's latest event NIDs.
// A room that has no latest events is reported as 0 with a nil error. If
// either lookup fails, the error is returned together with the room NID
// obtained from RoomNID.
func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) {
	if roomNID, err = d.RoomNID(ctx, roomID); err != nil {
		return roomNID, err
	}

	latest, _, lookupErr := d.RoomsTable.SelectLatestEventNIDs(ctx, nil, roomNID)
	switch {
	case lookupErr != nil:
		return roomNID, lookupErr
	case len(latest) == 0:
		// No latest events: report the room as absent.
		return 0, nil
	}
	return roomNID, nil
}
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
//
// Within a single internal.WithTransaction call it reads the room's latest
// event NIDs and current state snapshot NID, then the event references and
// the maximum depth for those event NIDs. The named results are assigned as
// each step completes, so on error the results of earlier steps are returned
// alongside the error.
func (d *Database) LatestEventIDs(
	ctx context.Context, roomNID types.RoomNID,
) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) {
	err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error {
		var latestNIDs []types.EventNID

		if latestNIDs, currentStateSnapshotNID, err = d.RoomsTable.SelectLatestEventNIDs(ctx, txn, roomNID); err != nil {
			return err
		}
		if references, err = d.EventsTable.BulkSelectEventReference(ctx, txn, latestNIDs); err != nil {
			return err
		}
		depth, err = d.EventsTable.SelectMaxEventDepth(ctx, txn, latestNIDs)
		return err
	})
	return references, currentStateSnapshotNID, depth, err
}
// GetRoomVersionForRoom returns the room version that the rooms table reports
// for the room with the given room ID. The lookup is made with a nil second
// argument, i.e. outside of any transaction.
func (d *Database) GetRoomVersionForRoom(
	ctx context.Context, roomID string,
) (gomatrixserverlib.RoomVersion, error) {
	version, err := d.RoomsTable.SelectRoomVersionForRoomID(ctx, nil, roomID)
	return version, err
}
// GetRoomVersionForRoomNID returns the room version that the rooms table
// reports for the room with the given numeric room ID.
func (d *Database) GetRoomVersionForRoomNID(
	ctx context.Context, roomNID types.RoomNID,
) (gomatrixserverlib.RoomVersion, error) {
	version, err := d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID)
	return version, err
}
// Events implements input.EventDatabase
//
// It bulk-loads the JSON for the given event NIDs and parses each row into a
// types.Event. Parsing needs the version of the room the event belongs to, so
// the room NID is looked up for every event. The room version is then looked
// up only once per distinct room NID and reused for the rest of the call,
// instead of being queried again for every event in the same room.
//
// The returned slice has one entry per row returned by
// EventJSONTable.BulkSelectEventJSON, in the same order. If any lookup or
// parse fails, a nil slice and that error are returned.
func (d *Database) Events(
	ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
	eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs)
	if err != nil {
		return nil, err
	}

	// Room versions already fetched during this call, keyed by room NID.
	// This assumes a room's version does not change while the call is running.
	roomVersions := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)

	results := make([]types.Event, len(eventJSONs))
	for i, eventJSON := range eventJSONs {
		result := &results[i]
		result.EventNID = eventJSON.EventNID

		var roomNID types.RoomNID
		roomNID, err = d.EventsTable.SelectRoomNIDForEventNID(ctx, eventJSON.EventNID)
		if err != nil {
			return nil, err
		}

		roomVersion, known := roomVersions[roomNID]
		if !known {
			roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID)
			if err != nil {
				return nil, err
			}
			roomVersions[roomNID] = roomVersion
		}

		result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON(
			eventJSON.EventJSON, false, roomVersion,
		)
		if err != nil {
			return nil, err
		}
	}
	return results, nil
}
// GetTransactionEventID implements input.EventDatabase
//
// It returns the event ID that the transactions table holds for the given
// transaction ID, session ID and user ID. A missing row (sql.ErrNoRows) is not
// treated as an error: in that case an empty string is returned with a nil
// error.
func (d *Database) GetTransactionEventID(
	ctx context.Context, transactionID string,
	sessionID int64, userID string,
) (string, error) {
	storedEventID, err := d.TransactionsTable.SelectTransactionEventID(ctx, transactionID, sessionID, userID)
	if err != sql.ErrNoRows {
		return storedEventID, err
	}
	return "", nil
}
// StoreEvent implements input.EventDatabase
//
// It persists the given event and returns the numeric ID of its room together
// with a types.StateAtEvent describing the stored event. All writes happen
// inside one internal.WithTransaction call on d.DB, in this order:
//
//  1. If txnAndSessionID is non-nil, record the transaction ID, session ID,
//     sender and event ID in the transactions table.
//  2. Assign (look up or insert) the numeric IDs for the room, the event type
//     and, when the event has one, the state key.
//  3. Insert the event row. sql.ErrNoRows from the insert is taken to mean
//     the event already exists, and its NIDs are selected instead.
//  4. Insert the event JSON.
//
// authEventNIDs is passed through to the event insert unchanged. On any error
// the zero room NID, an empty types.StateAtEvent and the error are returned.
func (d *Database) StoreEvent(
ctx context.Context, event gomatrixserverlib.Event,
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error) {
// These are assigned from inside the transaction closure below and read
// after it returns to build the result.
var (
roomNID types.RoomNID
eventTypeNID types.EventTypeNID
eventStateKeyNID types.EventStateKeyNID
eventNID types.EventNID
stateNID types.StateSnapshotNID
err error
)
err = internal.WithTransaction(d.DB, func(txn *sql.Tx) error {
// Only events that arrive with a transaction/session ID are recorded in
// the transactions table.
if txnAndSessionID != nil {
if err = d.TransactionsTable.InsertTransaction(
ctx, txn, txnAndSessionID.TransactionID,
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
); err != nil {
return err
}
}
// TODO: Here we should aim to have two different code paths for new rooms
// vs existing ones.
// Work out the room version from the event. Only an m.room.create event
// yields a version (defaulting to v1 when its content has no room_version
// key); for every other event type the version is the empty string.
// https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom
// Note that the below logic depends on the m.room.create event being the
// first event that is persisted to the database when creating or joining a
// room: assignRoomNID only uses the version when it has to insert a new
// room row.
var roomVersion gomatrixserverlib.RoomVersion
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
return err
}
if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil {
return err
}
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil {
return err
}
eventStateKey := event.StateKey()
// Assign a numeric ID for the state_key if there is one present.
// Otherwise the numeric ID for the state_key stays at its zero value.
if eventStateKey != nil {
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil {
return err
}
}
if eventNID, stateNID, err = d.EventsTable.InsertEvent(
ctx,
txn,
roomNID,
eventTypeNID,
eventStateKeyNID,
event.EventID(),
event.EventReference().EventSHA256,
authEventNIDs,
event.Depth(),
); err != nil {
if err == sql.ErrNoRows {
// We've already inserted the event so select the numeric event ID
eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID())
}
// Covers both a non-ErrNoRows insert failure and a failed fallback select.
if err != nil {
return err
}
}
// The JSON insert runs even when the event row already existed.
if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
return err
}
return nil
})
if err != nil {
return 0, types.StateAtEvent{}, err
}
// stateNID is whatever the insert (or the fallback select) reported for
// this event.
return roomNID, types.StateAtEvent{
BeforeStateSnapshotNID: stateNID,
StateEntry: types.StateEntry{
StateKeyTuple: types.StateKeyTuple{
EventTypeNID: eventTypeNID,
EventStateKeyNID: eventStateKeyNID,
},
EventNID: eventNID,
},
}, nil
}
// assignRoomNID returns the numeric ID for roomID, inserting a new row with
// the given room version when none exists yet. roomVersion is only used on
// the insert path. sql.ErrNoRows from the insert is treated as having raced
// with another insert, in which case the select is run once more.
func (d *Database) assignRoomNID(
	ctx context.Context, txn *sql.Tx,
	roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) {
	rooms := d.RoomsTable

	// Fast path: the room already has a numeric ID (or the lookup failed for
	// a reason other than the row being absent).
	nid, err := rooms.SelectRoomNID(ctx, txn, roomID)
	if err != sql.ErrNoRows {
		return nid, err
	}

	// No numeric ID yet, so try to create one.
	nid, err = rooms.InsertRoomNID(ctx, txn, roomID, roomVersion)
	if err != sql.ErrNoRows {
		return nid, err
	}

	// The insert produced no row: assume another insert won the race and
	// read back the ID it created.
	return rooms.SelectRoomNID(ctx, txn, roomID)
}
// assignEventTypeNID returns the numeric ID for eventType, inserting a new
// row when none exists yet. sql.ErrNoRows from the insert is treated as
// having raced with another insert, in which case the select is run once
// more.
func (d *Database) assignEventTypeNID(
	ctx context.Context, txn *sql.Tx, eventType string,
) (eventTypeNID types.EventTypeNID, err error) {
	eventTypes := d.EventTypesTable

	// Fast path: the type already has a numeric ID (or the lookup failed for
	// a reason other than the row being absent).
	eventTypeNID, err = eventTypes.SelectEventTypeNID(ctx, txn, eventType)
	if err != sql.ErrNoRows {
		return eventTypeNID, err
	}

	// No numeric ID yet, so try to create one.
	eventTypeNID, err = eventTypes.InsertEventTypeNID(ctx, txn, eventType)
	if err != sql.ErrNoRows {
		return eventTypeNID, err
	}

	// The insert produced no row: assume another insert won the race and
	// read back the ID it created.
	eventTypeNID, err = eventTypes.SelectEventTypeNID(ctx, txn, eventType)
	return eventTypeNID, err
}
// assignStateKeyNID returns the numeric ID for eventStateKey, inserting a new
// row when none exists yet. sql.ErrNoRows from the insert is treated as
// having raced with another insert, in which case the select is run once
// more.
func (d *Database) assignStateKeyNID(
	ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
	stateKeys := d.EventStateKeysTable

	// Fast path: the state key already has a numeric ID (or the lookup failed
	// for a reason other than the row being absent).
	nid, err := stateKeys.SelectEventStateKeyNID(ctx, txn, eventStateKey)
	if err != sql.ErrNoRows {
		return nid, err
	}

	// No numeric ID yet, so try to create one.
	nid, err = stateKeys.InsertEventStateKeyNID(ctx, txn, eventStateKey)
	if err != sql.ErrNoRows {
		return nid, err
	}

	// The insert produced no row: assume another insert won the race and
	// read back the ID it created.
	return stateKeys.SelectEventStateKeyNID(ctx, txn, eventStateKey)
}
// extractRoomVersionFromCreateEvent works out the room version declared by an
// m.room.create event.
//
// For any other event type it returns the empty room version and a nil error.
// For a create event it unmarshals the content: if that fails, the empty
// version and the unmarshal error are returned; if the content carries no
// "room_version" key, gomatrixserverlib.RoomVersionV1 is returned; otherwise
// the declared version is returned.
func extractRoomVersionFromCreateEvent(event gomatrixserverlib.Event) (
	gomatrixserverlib.RoomVersion, error,
) {
	// Only m.room.create events carry a room version.
	if event.Type() != gomatrixserverlib.MRoomCreate {
		return gomatrixserverlib.RoomVersion(""), nil
	}

	// The "room_version" key is optional, so the content has to be decoded
	// before we know whether one was supplied.
	var content gomatrixserverlib.CreateContent
	if err := json.Unmarshal(event.Content(), &content); err != nil {
		return gomatrixserverlib.RoomVersion(""), err
	}

	// No explicit version in the content: fall back to v1.
	if content.RoomVersion == nil {
		return gomatrixserverlib.RoomVersionV1, nil
	}
	return gomatrixserverlib.RoomVersion(*content.RoomVersion), nil
}