From 4218311e6aa2a6b134e56f4206f9ef87d863419e Mon Sep 17 00:00:00 2001
From: riperiperi <rhy3756547@hotmail.com>
Date: Sat, 17 Feb 2024 00:41:30 +0000
Subject: [PATCH] Vulkan: Use push descriptors for uniform bindings when
 possible (#6154)

* Fix Push Descriptors

* Use push descriptor templates

* Use reserved bindings

* Formatting

* Disable when using MVK

("my heart will go on" starts playing as thousands of mac users shed a tear in unison)

* Introduce limit on push descriptor binding number

The bitmask used for updating push descriptors is a 64-bit integer, so only 64 bindings can be tracked for now.

* Address feedback

* Fix logic for binding rejection

The limit should only be offset by reserved bindings that are lower than the requested binding.

* Work around a bug on Pascal and older NVIDIA GPUs

* Add GPU number detection for nvidia

* Only do workaround if it's valid to do so.
---
 src/Ryujinx.Graphics.Vulkan/Constants.cs      |   1 +
 .../DescriptorSetTemplate.cs                  | 102 ++++++++++++++++-
 .../DescriptorSetTemplateUpdater.cs           |  12 ++
 .../DescriptorSetUpdater.cs                   |  89 +++++++++++++--
 .../HardwareCapabilities.cs                   |   3 +
 src/Ryujinx.Graphics.Vulkan/PipelineBase.cs   |   2 +-
 .../PipelineLayoutCacheEntry.cs               |  40 +++++++
 .../ShaderCollection.cs                       | 103 +++++++++++++++++-
 src/Ryujinx.Graphics.Vulkan/Vendor.cs         |   3 +
 .../VulkanConfiguration.cs                    |   2 +-
 src/Ryujinx.Graphics.Vulkan/VulkanRenderer.cs |  54 ++++++++-
 11 files changed, 395 insertions(+), 16 deletions(-)

diff --git a/src/Ryujinx.Graphics.Vulkan/Constants.cs b/src/Ryujinx.Graphics.Vulkan/Constants.cs
index cd6122112..20ce65818 100644
--- a/src/Ryujinx.Graphics.Vulkan/Constants.cs
+++ b/src/Ryujinx.Graphics.Vulkan/Constants.cs
@@ -16,6 +16,7 @@ namespace Ryujinx.Graphics.Vulkan
         public const int MaxStorageBufferBindings = MaxStorageBuffersPerStage * MaxShaderStages;
         public const int MaxTextureBindings = MaxTexturesPerStage * MaxShaderStages;
         public const int MaxImageBindings = MaxImagesPerStage * MaxShaderStages;
+        public const int MaxPushDescriptorBinding = 64;
 
         public const ulong SparseBufferAlignment = 0x10000;
     }
diff --git a/src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplate.cs b/src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplate.cs
index 0c0004b95..b9abd8fcd 100644
--- a/src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplate.cs
+++ b/src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplate.cs
@@ -1,19 +1,32 @@
 using Ryujinx.Graphics.GAL;
 using Silk.NET.Vulkan;
 using System;
+using System.Numerics;
 using System.Runtime.CompilerServices;
 
 namespace Ryujinx.Graphics.Vulkan
 {
     class DescriptorSetTemplate : IDisposable
     {
+        /// <summary>
+        /// Renderdoc seems to crash when doing a templated uniform update with count > 1 on a push descriptor.
+        /// When this is true, consecutive buffers are always updated individually.
+        /// </summary>
+        private const bool RenderdocPushCountBug = true;
+
         private readonly VulkanRenderer _gd;
         private readonly Device _device;
 
         public readonly DescriptorUpdateTemplate Template;
         public readonly int Size;
 
-        public unsafe DescriptorSetTemplate(VulkanRenderer gd, Device device, ResourceBindingSegment[] segments, PipelineLayoutCacheEntry plce, PipelineBindPoint pbp, int setIndex)
+        public unsafe DescriptorSetTemplate(
+            VulkanRenderer gd,
+            Device device,
+            ResourceBindingSegment[] segments,
+            PipelineLayoutCacheEntry plce,
+            PipelineBindPoint pbp,
+            int setIndex)
         {
             _gd = gd;
             _device = device;
@@ -137,6 +150,93 @@ namespace Ryujinx.Graphics.Vulkan
             Template = result;
         }
 
+        public unsafe DescriptorSetTemplate(
+            VulkanRenderer gd,
+            Device device,
+            ResourceDescriptorCollection descriptors,
+            long updateMask,
+            PipelineLayoutCacheEntry plce,
+            PipelineBindPoint pbp,
+            int setIndex)
+        {
+            _gd = gd;
+            _device = device;
+
+            // Builds a push descriptor update template covering only the uniform buffer bindings whose bit is set in updateMask. Entries are emitted in descriptor order, then binding order, so the descriptor data must be written in that same order.
+            int segmentCount = BitOperations.PopCount((ulong)updateMask); // Upper bound on entries: one per set bit when no bindings are merged.
+
+            DescriptorUpdateTemplateEntry* entries = stackalloc DescriptorUpdateTemplateEntry[segmentCount];
+            int entry = 0;
+            nuint structureOffset = 0; // Byte offset of the next entry's data within the packed DescriptorBufferInfo array.
+
+            void AddBinding(int binding, int count) // Appends one template entry covering `count` consecutive bindings starting at `binding`.
+            {
+                entries[entry++] = new DescriptorUpdateTemplateEntry()
+                {
+                    DescriptorType = DescriptorType.UniformBuffer,
+                    DstBinding = (uint)binding,
+                    DescriptorCount = (uint)count,
+                    Offset = structureOffset,
+                    Stride = (nuint)Unsafe.SizeOf<DescriptorBufferInfo>()
+                };
+
+                structureOffset += (nuint)(Unsafe.SizeOf<DescriptorBufferInfo>() * count);
+            }
+
+            int startBinding = 0; // First binding of the run currently being accumulated.
+            int bindingCount = 0; // Length of that run; 0 means no run is pending.
+
+            foreach (ResourceDescriptor descriptor in descriptors.Descriptors)
+            {
+                for (int i = 0; i < descriptor.Count; i++)
+                {
+                    int binding = descriptor.Binding + i;
+
+                    if ((updateMask & (1L << binding)) != 0) // Shift counts wrap modulo 64 for long, so bindings are expected to be below 64.
+                    {
+                        if (bindingCount > 0 && (RenderdocPushCountBug || startBinding + bindingCount != binding)) // Flush the pending run when it can't be extended; with RenderdocPushCountBug set, every binding is flushed on its own.
+                        {
+                            AddBinding(startBinding, bindingCount);
+
+                            bindingCount = 0;
+                        }
+
+                        if (bindingCount == 0)
+                        {
+                            startBinding = binding;
+                        }
+
+                        bindingCount++;
+                    }
+                }
+            }
+
+            if (bindingCount > 0) // Flush the final pending run, if any.
+            {
+                AddBinding(startBinding, bindingCount);
+            }
+
+            Size = (int)structureOffset; // Total bytes of descriptor data covered by the template entries.
+
+            var info = new DescriptorUpdateTemplateCreateInfo()
+            {
+                SType = StructureType.DescriptorUpdateTemplateCreateInfo,
+                DescriptorUpdateEntryCount = (uint)entry,
+                PDescriptorUpdateEntries = entries,
+
+                TemplateType = DescriptorUpdateTemplateType.PushDescriptorsKhr,
+                DescriptorSetLayout = plce.DescriptorSetLayouts[setIndex],
+                PipelineBindPoint = pbp,
+                PipelineLayout = plce.PipelineLayout,
+                Set = (uint)setIndex,
+            };
+
+            DescriptorUpdateTemplate result;
+            gd.Api.CreateDescriptorUpdateTemplate(device, &info, null, &result).ThrowOnError();
+
+            Template = result;
+        }
+
         public unsafe void Dispose()
         {
             _gd.Api.DestroyDescriptorUpdateTemplate(_device, Template, null);
diff --git a/src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplateUpdater.cs b/src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplateUpdater.cs
index 1eb9dce75..88db7e769 100644
--- a/src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplateUpdater.cs
+++ b/src/Ryujinx.Graphics.Vulkan/DescriptorSetTemplateUpdater.cs
@@ -52,11 +52,23 @@ namespace Ryujinx.Graphics.Vulkan
             return new DescriptorSetTemplateWriter(new Span<byte>(_data.Pointer, template.Size));
         }
 
+        public DescriptorSetTemplateWriter Begin(int maxSize)
+        {
+            EnsureSize(maxSize); // Runs before _data.Pointer is read; presumably it (re)allocates the staging buffer.
+            var staging = new Span<byte>(_data.Pointer, maxSize);
+            return new DescriptorSetTemplateWriter(staging);
+        }
+
         public void Commit(VulkanRenderer gd, Device device, DescriptorSet set)
         {
             gd.Api.UpdateDescriptorSetWithTemplate(device, set, _activeTemplate.Template, _data.Pointer);
         }
 
+        public void CommitPushDescriptor(VulkanRenderer gd, CommandBufferScoped cbs, DescriptorSetTemplate template, PipelineLayout layout) =>
+            // Records a templated push of the staged descriptor data (set index 0) into the given command buffer.
+            gd.PushDescriptorApi.CmdPushDescriptorSetWithTemplate(
+                cbs.CommandBuffer, template.Template, layout, 0, _data.Pointer);
+
         public void Dispose()
         {
             _data?.Dispose();
diff --git a/src/Ryujinx.Graphics.Vulkan/DescriptorSetUpdater.cs b/src/Ryujinx.Graphics.Vulkan/DescriptorSetUpdater.cs
index 6615d8ce0..d40b201da 100644
--- a/src/Ryujinx.Graphics.Vulkan/DescriptorSetUpdater.cs
+++ b/src/Ryujinx.Graphics.Vulkan/DescriptorSetUpdater.cs
@@ -4,6 +4,7 @@ using Ryujinx.Graphics.Shader;
 using Silk.NET.Vulkan;
 using System;
 using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
 using CompareOp = Ryujinx.Graphics.GAL.CompareOp;
 using Format = Ryujinx.Graphics.GAL.Format;
 using SamplerCreateInfo = Ryujinx.Graphics.GAL.SamplerCreateInfo;
@@ -61,6 +62,8 @@ namespace Ryujinx.Graphics.Vulkan
         private BitMapStruct<Array2<long>> _storageSet;
         private BitMapStruct<Array2<long>> _uniformMirrored;
         private BitMapStruct<Array2<long>> _storageMirrored;
+        private readonly int[] _uniformSetPd;
+        private int _pdSequence = 1;
 
         private bool _updateDescriptorCacheCbIndex;
 
@@ -106,6 +109,8 @@ namespace Ryujinx.Graphics.Vulkan
             _bufferTextures = new BufferView[Constants.MaxTexturesPerStage];
             _bufferImages = new BufferView[Constants.MaxImagesPerStage];
 
+            _uniformSetPd = new int[Constants.MaxUniformBufferBindings];
+
             var initialImageInfo = new DescriptorImageInfo
             {
                 ImageLayout = ImageLayout.General,
@@ -193,6 +198,7 @@ namespace Ryujinx.Graphics.Vulkan
                         if (BindingOverlaps(ref info, bindingOffset, offset, size))
                         {
                             _uniformSet.Clear(binding);
+                            _uniformSetPd[binding] = 0;
                             SignalDirty(DirtyFlags.Uniform);
                         }
                     }
@@ -223,8 +229,30 @@ namespace Ryujinx.Graphics.Vulkan
             });
         }
 
-        public void SetProgram(ShaderCollection program)
+        public void AdvancePdSequence()
+        {
+            // Sequence 0 marks an entry of _uniformSetPd as "not pushed", so it is skipped when the counter wraps around.
+            int next = _pdSequence + 1;
+
+            _pdSequence = next != 0 ? next : 1;
+        }
+
+        public void SetProgram(CommandBufferScoped cbs, ShaderCollection program, bool isBound)
+        {
+            if (!program.HasSameLayout(_program))
+            {
+                // When the pipeline layout changes, push descriptor bindings are invalidated.
+
+                AdvancePdSequence();
+
+                if (_gd.IsNvidiaPreTuring && !program.UsePushDescriptors && _program?.UsePushDescriptors == true && isBound)
+                {
+                    // On older nvidia GPUs, we need to clear out the active push descriptor bindings when switching
+                    // to normal descriptors. Keeping them bound can prevent buffers from binding properly in future.
+                    ClearAndBindUniformBufferPd(cbs);
+                }
+            }
+
             _program = program;
             _updateDescriptorCacheCbIndex = true;
             _dirty = DirtyFlags.All;
@@ -402,6 +430,7 @@ namespace Ryujinx.Graphics.Vulkan
                 if (!currentBufferRef.Equals(newRef) || currentInfo.Range != info.Range)
                 {
                     _uniformSet.Clear(index);
+                    _uniformSetPd[index] = 0;
 
                     currentInfo = info;
                     currentBufferRef = newRef;
@@ -671,15 +700,19 @@ namespace Ryujinx.Graphics.Vulkan
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private void UpdateAndBindUniformBufferPd(CommandBufferScoped cbs, PipelineBindPoint pbp)
         {
+            int sequence = _pdSequence;
             var bindingSegments = _program.BindingSegments[PipelineBase.UniformSetIndex];
             var dummyBuffer = _dummyBuffer?.GetBuffer();
 
+            long updatedBindings = 0;
+            DescriptorSetTemplateWriter writer = _templateUpdater.Begin(32 * Unsafe.SizeOf<DescriptorBufferInfo>());
+
             foreach (ResourceBindingSegment segment in bindingSegments)
             {
                 int binding = segment.Binding;
                 int count = segment.Count;
 
-                bool doUpdate = false;
+                ReadOnlySpan<DescriptorBufferInfo> uniformBuffers = _uniformBuffers;
 
                 for (int i = 0; i < count; i++)
                 {
@@ -688,17 +721,58 @@ namespace Ryujinx.Graphics.Vulkan
                     if (_uniformSet.Set(index))
                     {
                         ref BufferRef buffer = ref _uniformBufferRefs[index];
-                        UpdateBuffer(cbs, ref _uniformBuffers[index], ref buffer, dummyBuffer, true);
-                        doUpdate = true;
+
+                        bool mirrored = UpdateBuffer(cbs, ref _uniformBuffers[index], ref buffer, dummyBuffer, true);
+
+                        _uniformMirrored.Set(index, mirrored);
+                    }
+
+                    if (_uniformSetPd[index] != sequence)
+                    {
+                        // Need to set this push descriptor (even if the buffer binding has not changed)
+
+                        _uniformSetPd[index] = sequence;
+                        updatedBindings |= 1L << index;
+
+                        writer.Push(MemoryMarshal.CreateReadOnlySpan(ref _uniformBuffers[index], 1));
                     }
                 }
+            }
 
-                if (doUpdate)
+            if (updatedBindings != 0) // The mask is a signed long: binding 63 sets the sign bit, so "> 0" would silently skip the push.
+            {
+                DescriptorSetTemplate template = _program.GetPushDescriptorTemplate(updatedBindings);
+                _templateUpdater.CommitPushDescriptor(_gd, cbs, template, _program.PipelineLayout);
+            }
+        }
+
+        private void ClearAndBindUniformBufferPd(CommandBufferScoped cbs) // Pushes a default (empty) DescriptorBufferInfo to every uniform binding of the current program.
+        {
+            var bindingSegments = _program.BindingSegments[PipelineBase.UniformSetIndex];
+
+            long updatedBindings = 0;
+            DescriptorSetTemplateWriter writer = _templateUpdater.Begin(Constants.MaxPushDescriptorBinding * Unsafe.SizeOf<DescriptorBufferInfo>()); // Every binding is written here, so size for the largest mask the bitmask can describe.
+
+            foreach (ResourceBindingSegment segment in bindingSegments)
+            {
+                int binding = segment.Binding;
+                int count = segment.Count;
+
+                for (int i = 0; i < count; i++)
                 {
-                    ReadOnlySpan<DescriptorBufferInfo> uniformBuffers = _uniformBuffers;
-                    UpdateBuffers(cbs, pbp, binding, uniformBuffers.Slice(binding, count), DescriptorType.UniformBuffer);
+                    int index = binding + i;
+                    updatedBindings |= 1L << index;
+
+                    var bufferInfo = new DescriptorBufferInfo(); // Default-initialized: no buffer, zero offset and range.
+                    writer.Push(MemoryMarshal.CreateReadOnlySpan(ref bufferInfo, 1));
                 }
             }
+
+            if (updatedBindings != 0) // The mask is a signed long: binding 63 sets the sign bit, so "> 0" would silently skip the push.
+            {
+                DescriptorSetTemplate template = _program.GetPushDescriptorTemplate(updatedBindings);
+                _templateUpdater.CommitPushDescriptor(_gd, cbs, template, _program.PipelineLayout);
+            }
         }
 
         private void Initialize(CommandBufferScoped cbs, int setIndex, DescriptorSetCollection dsc)
@@ -724,6 +798,7 @@ namespace Ryujinx.Graphics.Vulkan
 
             _uniformSet.Clear();
             _storageSet.Clear();
+            AdvancePdSequence();
         }
 
         private static void SwapBuffer(BufferRef[] list, Auto<DisposableBuffer> from, Auto<DisposableBuffer> to)
diff --git a/src/Ryujinx.Graphics.Vulkan/HardwareCapabilities.cs b/src/Ryujinx.Graphics.Vulkan/HardwareCapabilities.cs
index 98c777eed..b6694bcb3 100644
--- a/src/Ryujinx.Graphics.Vulkan/HardwareCapabilities.cs
+++ b/src/Ryujinx.Graphics.Vulkan/HardwareCapabilities.cs
@@ -34,6 +34,7 @@ namespace Ryujinx.Graphics.Vulkan
         public readonly bool SupportsMultiView;
         public readonly bool SupportsNullDescriptors;
         public readonly bool SupportsPushDescriptors;
+        public readonly uint MaxPushDescriptors;
         public readonly bool SupportsPrimitiveTopologyListRestart;
         public readonly bool SupportsPrimitiveTopologyPatchListRestart;
         public readonly bool SupportsTransformFeedback;
@@ -71,6 +72,7 @@ namespace Ryujinx.Graphics.Vulkan
             bool supportsMultiView,
             bool supportsNullDescriptors,
             bool supportsPushDescriptors,
+            uint maxPushDescriptors,
             bool supportsPrimitiveTopologyListRestart,
             bool supportsPrimitiveTopologyPatchListRestart,
             bool supportsTransformFeedback,
@@ -107,6 +109,7 @@ namespace Ryujinx.Graphics.Vulkan
             SupportsMultiView = supportsMultiView;
             SupportsNullDescriptors = supportsNullDescriptors;
             SupportsPushDescriptors = supportsPushDescriptors;
+            MaxPushDescriptors = maxPushDescriptors;
             SupportsPrimitiveTopologyListRestart = supportsPrimitiveTopologyListRestart;
             SupportsPrimitiveTopologyPatchListRestart = supportsPrimitiveTopologyPatchListRestart;
             SupportsTransformFeedback = supportsTransformFeedback;
diff --git a/src/Ryujinx.Graphics.Vulkan/PipelineBase.cs b/src/Ryujinx.Graphics.Vulkan/PipelineBase.cs
index 3aef1317a..3b3f59259 100644
--- a/src/Ryujinx.Graphics.Vulkan/PipelineBase.cs
+++ b/src/Ryujinx.Graphics.Vulkan/PipelineBase.cs
@@ -976,7 +976,7 @@ namespace Ryujinx.Graphics.Vulkan
 
             _program = internalProgram;
 
-            _descriptorSetUpdater.SetProgram(internalProgram);
+            _descriptorSetUpdater.SetProgram(Cbs, internalProgram, _currentPipelineHandle != 0);
 
             _newState.PipelineLayout = internalProgram.PipelineLayout;
             _newState.StagesCount = (uint)stages.Length;
diff --git a/src/Ryujinx.Graphics.Vulkan/PipelineLayoutCacheEntry.cs b/src/Ryujinx.Graphics.Vulkan/PipelineLayoutCacheEntry.cs
index 2840dda0f..f388d9e88 100644
--- a/src/Ryujinx.Graphics.Vulkan/PipelineLayoutCacheEntry.cs
+++ b/src/Ryujinx.Graphics.Vulkan/PipelineLayoutCacheEntry.cs
@@ -31,6 +31,11 @@ namespace Ryujinx.Graphics.Vulkan
         private int _dsLastCbIndex;
         private int _dsLastSubmissionCount;
 
+        private readonly Dictionary<long, DescriptorSetTemplate> _pdTemplates;
+        private readonly ResourceDescriptorCollection _pdDescriptors;
+        private long _lastPdUsage;
+        private DescriptorSetTemplate _lastPdTemplate;
+
         private PipelineLayoutCacheEntry(VulkanRenderer gd, Device device, int setsCount)
         {
             _gd = gd;
@@ -72,6 +77,12 @@ namespace Ryujinx.Graphics.Vulkan
 
                 _consumedDescriptorsPerSet[setIndex] = count;
             }
+
+            if (usePushDescriptors)
+            {
+                _pdDescriptors = setDescriptors[0];
+                _pdTemplates = new();
+            }
         }
 
         public void UpdateCommandBufferIndex(int commandBufferIndex)
@@ -143,10 +154,39 @@ namespace Ryujinx.Graphics.Vulkan
             return output[..count];
         }
 
+        public DescriptorSetTemplate GetPushDescriptorTemplate(PipelineBindPoint pbp, long updateMask)
+        {
+            // Fast path: consecutive requests usually ask for the same set of bindings.
+            bool repeatsLastRequest = _lastPdUsage == updateMask && _lastPdTemplate != null;
+
+            if (repeatsLastRequest)
+            {
+                return _lastPdTemplate;
+            }
+
+            // Slow path: look the mask up in the per-layout cache, creating the template on a miss.
+            if (!_pdTemplates.TryGetValue(updateMask, out DescriptorSetTemplate cached))
+            {
+                cached = new DescriptorSetTemplate(_gd, _device, _pdDescriptors, updateMask, this, pbp, 0);
+                _pdTemplates[updateMask] = cached;
+            }
+
+            (_lastPdUsage, _lastPdTemplate) = (updateMask, cached);
+            return cached;
+        }
+
         protected virtual unsafe void Dispose(bool disposing)
         {
             if (disposing)
             {
+                if (_pdTemplates != null)
+                {
+                    foreach (DescriptorSetTemplate template in _pdTemplates.Values)
+                    {
+                        template.Dispose();
+                    }
+                }
+
                 for (int i = 0; i < _dsCache.Length; i++)
                 {
                     for (int j = 0; j < _dsCache[i].Length; j++)
diff --git a/src/Ryujinx.Graphics.Vulkan/ShaderCollection.cs b/src/Ryujinx.Graphics.Vulkan/ShaderCollection.cs
index 7c25c6d14..3c35a6f0e 100644
--- a/src/Ryujinx.Graphics.Vulkan/ShaderCollection.cs
+++ b/src/Ryujinx.Graphics.Vulkan/ShaderCollection.cs
@@ -108,18 +108,25 @@ namespace Ryujinx.Graphics.Vulkan
 
             _shaders = internalShaders;
 
-            bool usePushDescriptors = !isMinimal && VulkanConfiguration.UsePushDescriptors && _gd.Capabilities.SupportsPushDescriptors;
+            bool usePushDescriptors = !isMinimal &&
+                VulkanConfiguration.UsePushDescriptors &&
+                _gd.Capabilities.SupportsPushDescriptors &&
+                !IsCompute &&
+                CanUsePushDescriptors(gd, resourceLayout, IsCompute);
 
-            _plce = gd.PipelineLayoutCache.GetOrCreate(gd, device, resourceLayout.Sets, usePushDescriptors);
+            ReadOnlyCollection<ResourceDescriptorCollection> sets = usePushDescriptors ?
+                BuildPushDescriptorSets(gd, resourceLayout.Sets) : resourceLayout.Sets;
+
+            _plce = gd.PipelineLayoutCache.GetOrCreate(gd, device, sets, usePushDescriptors);
 
             HasMinimalLayout = isMinimal;
             UsePushDescriptors = usePushDescriptors;
 
             Stages = stages;
 
-            ClearSegments = BuildClearSegments(resourceLayout.Sets);
+            ClearSegments = BuildClearSegments(sets);
             BindingSegments = BuildBindingSegments(resourceLayout.SetUsages);
-            Templates = BuildTemplates();
+            Templates = BuildTemplates(usePushDescriptors);
 
             _compileTask = Task.CompletedTask;
             _firstBackgroundUse = false;
@@ -139,6 +146,76 @@ namespace Ryujinx.Graphics.Vulkan
             _firstBackgroundUse = !fromCache;
         }
 
+        private static bool CanUsePushDescriptors(VulkanRenderer gd, ResourceLayout layout, bool isCompute)
+        {
+            // Compute reserves nothing; otherwise the alternate reserved set is requested when binding 3 is in use.
+            ReadOnlyCollection<ResourceUsage> usages = layout.SetUsages[0].Usages;
+            bool usesBinding3 = usages.Any(usage => usage.Binding == 3);
+            int[] reserved = isCompute ? Array.Empty<int>() : gd.GetPushDescriptorReservedBindings(usesBinding3);
+
+            foreach (ResourceUsage usage in usages)
+            {
+                var binding = usage.Binding;
+                bool rejected = reserved.Contains(binding) || // Collides with a reserved binding.
+                    binding >= Constants.MaxPushDescriptorBinding || // Beyond what the update bitmask can track.
+                    binding >= gd.Capabilities.MaxPushDescriptors + reserved.Count(id => id < binding); // Beyond the device limit, offset by the reserved bindings below it.
+
+                if (rejected)
+                {
+                    return false;
+                }
+            }
+
+            return true;
+        }
+
+        private static ReadOnlyCollection<ResourceDescriptorCollection> BuildPushDescriptorSets(
+            VulkanRenderer gd,
+            ReadOnlyCollection<ResourceDescriptorCollection> sets)
+        {
+            // The reserved binding list is expected to have been selected already, when push descriptor eligibility was checked.
+            int[] reserved = gd.GetPushDescriptorReservedBindings(false);
+
+            var rebuilt = new ResourceDescriptorCollection[sets.Count];
+
+            for (int setIndex = 0; setIndex < rebuilt.Length; setIndex++)
+            {
+                ResourceDescriptorCollection source = sets[setIndex];
+
+                if (setIndex != 0)
+                {
+                    // Only set 0 uses push descriptors; every other set is passed through untouched.
+                    rebuilt[setIndex] = source;
+                    continue;
+                }
+
+                var pushDescriptors = new ResourceDescriptor[source.Descriptors.Count];
+                int slot = 0;
+
+                foreach (ResourceDescriptor descriptor in source.Descriptors)
+                {
+                    bool isReserved = reserved.Contains(descriptor.Binding);
+
+                    if (!isReserved)
+                    {
+                        pushDescriptors[slot++] = descriptor;
+                        continue;
+                    }
+
+                    // Reserved bindings stay in the layout, but with a descriptor count of 0.
+                    pushDescriptors[slot++] = new ResourceDescriptor(
+                        descriptor.Binding,
+                        0,
+                        descriptor.Type,
+                        descriptor.Stages);
+                }
+
+                rebuilt[setIndex] = new ResourceDescriptorCollection(new(pushDescriptors));
+            }
+
+            return new(rebuilt);
+        }
+
         private static ResourceBindingSegment[][] BuildClearSegments(ReadOnlyCollection<ResourceDescriptorCollection> sets)
         {
             ResourceBindingSegment[][] segments = new ResourceBindingSegment[sets.Count][];
@@ -243,12 +320,18 @@ namespace Ryujinx.Graphics.Vulkan
             return segments;
         }
 
-        private DescriptorSetTemplate[] BuildTemplates()
+        private DescriptorSetTemplate[] BuildTemplates(bool usePushDescriptors)
         {
             var templates = new DescriptorSetTemplate[BindingSegments.Length];
 
             for (int setIndex = 0; setIndex < BindingSegments.Length; setIndex++)
             {
+                if (usePushDescriptors && setIndex == 0)
+                {
+                    // Push descriptors get updated using templates owned by the pipeline layout.
+                    continue;
+                }
+
                 ResourceBindingSegment[] segments = BindingSegments[setIndex];
 
                 if (segments != null && segments.Length > 0)
@@ -433,6 +516,11 @@ namespace Ryujinx.Graphics.Vulkan
             return null;
         }
 
+        public DescriptorSetTemplate GetPushDescriptorTemplate(long updateMask) =>
+            // Delegates to the pipeline layout cache entry, choosing the bind point from the program kind.
+            _plce.GetPushDescriptorTemplate(
+                IsCompute ? PipelineBindPoint.Compute : PipelineBindPoint.Graphics, updateMask);
+
         public void AddComputePipeline(ref SpecData key, Auto<DisposablePipeline> pipeline)
         {
             (_computePipelineCache ??= new()).Add(ref key, pipeline);
@@ -493,6 +581,11 @@ namespace Ryujinx.Graphics.Vulkan
             return _plce.GetNewDescriptorSetCollection(setIndex, out isNew);
         }
 
+        public bool HasSameLayout(ShaderCollection other) =>
+            // Programs share a layout when both reference the same pipeline layout cache entry; null never matches.
+            other != null &&
+            _plce == other._plce;
+
         protected virtual void Dispose(bool disposing)
         {
             if (disposing)
diff --git a/src/Ryujinx.Graphics.Vulkan/Vendor.cs b/src/Ryujinx.Graphics.Vulkan/Vendor.cs
index 2d2f17b25..ff841dec9 100644
--- a/src/Ryujinx.Graphics.Vulkan/Vendor.cs
+++ b/src/Ryujinx.Graphics.Vulkan/Vendor.cs
@@ -20,6 +20,9 @@ namespace Ryujinx.Graphics.Vulkan
         [GeneratedRegex("Radeon (((HD|R(5|7|9|X)) )?((M?[2-6]\\d{2}(\\D|$))|([7-8]\\d{3}(\\D|$))|Fury|Nano))|(Pro Duo)")]
         public static partial Regex AmdGcnRegex();
 
+        [GeneratedRegex("NVIDIA GeForce (R|G)?TX? (\\d{3}\\d?)M?")]
+        public static partial Regex NvidiaConsumerClassRegex();
+
         public static Vendor FromId(uint id)
         {
             return id switch
diff --git a/src/Ryujinx.Graphics.Vulkan/VulkanConfiguration.cs b/src/Ryujinx.Graphics.Vulkan/VulkanConfiguration.cs
index a1fdc4aed..596c0e176 100644
--- a/src/Ryujinx.Graphics.Vulkan/VulkanConfiguration.cs
+++ b/src/Ryujinx.Graphics.Vulkan/VulkanConfiguration.cs
@@ -4,7 +4,7 @@ namespace Ryujinx.Graphics.Vulkan
     {
         public const bool UseFastBufferUpdates = true;
         public const bool UseUnsafeBlit = true;
-        public const bool UsePushDescriptors = false;
+        public const bool UsePushDescriptors = true;
 
         public const bool ForceD24S8Unsupported = false;
         public const bool ForceRGB16IntFloatUnsupported = false;
diff --git a/src/Ryujinx.Graphics.Vulkan/VulkanRenderer.cs b/src/Ryujinx.Graphics.Vulkan/VulkanRenderer.cs
index 48f05fa19..6aa46b79a 100644
--- a/src/Ryujinx.Graphics.Vulkan/VulkanRenderer.cs
+++ b/src/Ryujinx.Graphics.Vulkan/VulkanRenderer.cs
@@ -76,10 +76,15 @@ namespace Ryujinx.Graphics.Vulkan
         private readonly Func<string[]> _getRequiredExtensions;
         private readonly string _preferredGpuId;
 
+        private int[] _pdReservedBindings;
+        private readonly static int[] _pdReservedBindingsNvn = { 3, 18, 21, 36, 30 };
+        private readonly static int[] _pdReservedBindingsOgl = { 17, 18, 34, 35, 36 };
+
         internal Vendor Vendor { get; private set; }
         internal bool IsAmdWindows { get; private set; }
         internal bool IsIntelWindows { get; private set; }
         internal bool IsAmdGcn { get; private set; }
+        internal bool IsNvidiaPreTuring { get; private set; }
         internal bool IsMoltenVk { get; private set; }
         internal bool IsTBDR { get; private set; }
         internal bool IsSharedMemory { get; private set; }
@@ -191,6 +196,19 @@ namespace Ryujinx.Graphics.Vulkan
                 SType = StructureType.PhysicalDevicePortabilitySubsetPropertiesKhr,
             };
 
+            bool supportsPushDescriptors = _physicalDevice.IsDeviceExtensionPresent(KhrPushDescriptor.ExtensionName);
+
+            PhysicalDevicePushDescriptorPropertiesKHR propertiesPushDescriptor = new PhysicalDevicePushDescriptorPropertiesKHR()
+            {
+                SType = StructureType.PhysicalDevicePushDescriptorPropertiesKhr
+            };
+
+            if (supportsPushDescriptors)
+            {
+                propertiesPushDescriptor.PNext = properties2.PNext;
+                properties2.PNext = &propertiesPushDescriptor;
+            }
+
             PhysicalDeviceFeatures2 features2 = new()
             {
                 SType = StructureType.PhysicalDeviceFeatures2,
@@ -320,7 +338,8 @@ namespace Ryujinx.Graphics.Vulkan
                 _physicalDevice.IsDeviceExtensionPresent(ExtExtendedDynamicState.ExtensionName),
                 features2.Features.MultiViewport && !(IsMoltenVk && Vendor == Vendor.Amd), // Workaround for AMD on MoltenVK issue
                 featuresRobustness2.NullDescriptor || IsMoltenVk,
-                _physicalDevice.IsDeviceExtensionPresent(KhrPushDescriptor.ExtensionName),
+                supportsPushDescriptors && !IsMoltenVk,
+                propertiesPushDescriptor.MaxPushDescriptors,
                 featuresPrimitiveTopologyListRestart.PrimitiveTopologyListRestart,
                 featuresPrimitiveTopologyListRestart.PrimitiveTopologyPatchListRestart,
                 supportsTransformFeedback,
@@ -400,6 +419,25 @@ namespace Ryujinx.Graphics.Vulkan
             _initialized = true;
         }
 
+        internal int[] GetPushDescriptorReservedBindings(bool isOgl)
+        {
+            // The reserved set is decided once, by the first caller, and reused for every shader on this renderer (later calls ignore isOgl).
+            // This is chosen to minimize shaders that can't fit their uniforms on the device's max number of push descriptors.
+            if (_pdReservedBindings != null)
+            {
+                return _pdReservedBindings;
+            }
+
+            // Bindings are only reserved when the device's push descriptor limit is at or below the threshold.
+            bool limited = Capabilities.MaxPushDescriptors <= Constants.MaxUniformBuffersPerStage * 2;
+
+            _pdReservedBindings = limited
+                ? (isOgl ? _pdReservedBindingsOgl : _pdReservedBindingsNvn)
+                : Array.Empty<int>();
+
+            return _pdReservedBindings;
+        }
+
         public BufferHandle CreateBuffer(int size, BufferAccess access)
         {
             return BufferManager.CreateWithHandle(this, size, access.HasFlag(BufferAccess.SparseCompatible), access.Convert(), default, access == BufferAccess.Stream);
@@ -716,6 +754,20 @@ namespace Ryujinx.Graphics.Vulkan
 
             IsAmdGcn = !IsMoltenVk && Vendor == Vendor.Amd && VendorUtils.AmdGcnRegex().IsMatch(GpuRenderer);
 
+            if (Vendor == Vendor.Nvidia) // Detect pre-Turing NVIDIA GPUs from the device name.
+            {
+                var match = VendorUtils.NvidiaConsumerClassRegex().Match(GpuRenderer);
+
+                if (match.Success && int.TryParse(match.Groups[2].Value, out int gpuNumber)) // Regex.Match never returns null; Success is the real "matched" signal.
+                {
+                    IsNvidiaPreTuring = gpuNumber < 2000; // NOTE(review): GTX 16xx parts are Turing but fall below 2000; confirm the workaround is harmless there.
+                }
+                else if (GpuRenderer.Contains("TITAN") && !GpuRenderer.Contains("RTX")) // "TITAN" is a device name, so test the same string the regex was matched against.
+                {
+                    IsNvidiaPreTuring = true;
+                }
+            }
+
             Logger.Notice.Print(LogClass.Gpu, $"{GpuVendor} {GpuRenderer} ({GpuVersion})");
         }