diff --git a/src/xrt/auxiliary/bindings/bindings.py b/src/xrt/auxiliary/bindings/bindings.py index f66b269ee..3039629ef 100755 --- a/src/xrt/auxiliary/bindings/bindings.py +++ b/src/xrt/auxiliary/bindings/bindings.py @@ -262,6 +262,101 @@ class Identifier: self.components = component_list self.dpad = dpad return + + +class FeatureSet: + """An AND of requirements (versions and/or extensions) under which a binding becomes available""" + + def __init__(self, required_extensions=None, required_version=None): + self.required_extensions = frozenset(required_extensions if required_extensions is not None else []) + self.required_version = required_version if required_version is not None else {"major": "0", "minor": "0"} + + def as_tuple(self): + return (self.required_version["major"], + self.required_version["minor"], + sorted(self.required_extensions)) + + def __str__(self) -> str: + return f"{self.as_tuple()}" + + def is_more_restrictive_than(self, other): + if other.required_extensions.issuperset(self.required_extensions): + # other requires extensions we don't + return False + if check_promoted(other.required_version) and not check_promoted(self.required_version): + # other bounds version, we don't + return False + if check_promoted(other.required_version) and check_promoted(self.required_version): + if other.required_version["major"] != self.required_version["major"]: + # different major versions - not fully implemented, but this seems right + return False + if int(other.required_version["minor"]) > int(self.required_version["minor"]): + # other has a higher lower bound on minor version than us + return False + + return True + + def and_also(self, other): + result = copy.deepcopy(self) + result.required_extensions = self.required_extensions | other.required_extensions + if not check_promoted(result.required_version): + result.required_version = other.required_version + elif check_promoted(other.required_version): + if result.required_version["major"] != other.required_version["major"]: + raise NotImplementedError("Major version mismatch not handled") + if int(result.required_version["minor"]) < int(other.required_version["minor"]): + result.required_version["minor"] = other.required_version["minor"] + + return result + + +class Availability: + """An OR of FeatureSets, where any one of them being satisfied means a binding becomes available""" + + def __init__(self, feature_sets, optimize=True): + if not optimize: + self.feature_sets = set(feature_sets) + return + + self.feature_sets = set() + for feature_set in feature_sets: + self.add_in_place(feature_set) + + def __str__(self) -> str: + return f"{[str(fs) for fs in sorted(self.feature_sets, key=FeatureSet.as_tuple)]}" + + """Add an additional way for this availability to be satisfied""" + def add_in_place(self, new_feature_set): + for existing_feature in list(self.feature_sets): + if existing_feature.is_more_restrictive_than(new_feature_set): + self.feature_sets.remove(existing_feature) + elif new_feature_set.is_more_restrictive_than(existing_feature): + return + self.feature_sets.add(new_feature_set) + + """Add an additional restriction to all feature sets""" + def intersect_with_feature(self, feature_set): + result = Availability(feature_sets=[]) + for existing_feature in self.feature_sets: + result.add_in_place(existing_feature.and_also(feature_set)) + return result + + """Combine two availabilities into one that is satisfied if either is""" + def union(self, other): + result = copy.deepcopy(self) + for feature_set in other.feature_sets: + result.add_in_place(feature_set) + + """Combine two availabilities into one that is satisfied if both are. + Note that this acts as a cartesian product followed by an at least n^2 op on that. + """ + def intersection(self, other): + features = [] + for feature_set in other.feature_sets: + inter = self.intersect_with_feature(feature_set) + features.extend(inter.feature_sets) + + return Availability(features) class Profile: @@ -368,6 +463,20 @@ class Profile: self.components += identifier.components self.components = sorted(self.components, key=attrgetter("steamvr_path")) + def availability(self): + result = Availability(feature_sets=[]) + has_requirements = False + if check_promoted(self.openxr_version_promoted): + result.add_in_place(FeatureSet(required_version=self.openxr_version_promoted)) + has_requirements = True + if self.extension_name is not None: + result.add_in_place(FeatureSet(required_extensions=[self.extension_name])) + has_requirements = True + if not has_requirements: + result.add_in_place(FeatureSet()) + + return result + oxr_verify_extension_status_struct_name = "oxr_verify_extension_status" @@ -456,9 +565,9 @@ if_strcmp = '''if (strcmp(str, "{check}") == 0) {{ def check_promoted(openxr_version_promoted): # If required version is 0.0, we can skip checking that the instance uses a more recent version - return openxr_version_promoted is not None and openxr_version_promoted["major"] != '0' and openxr_version_promoted["minor"] != '0' + return openxr_version_promoted is not None and not (openxr_version_promoted["major"] == '0' and openxr_version_promoted["minor"] == '0') -def write_verify_switch_body(f, dict_of_lists, profile, profile_name, ext_name, tab_char): +def write_verify_switch_body(f, dict_of_lists, profile, profile_name, tab_char): """Generate function to check if a string is in a set of strings. Input is a file to write the code into, a dict where keys are length and the values are lists of strings of that length. And a suffix if any.""" @@ -470,53 +579,59 @@ def write_verify_switch_body(f, dict_of_lists, profile, profile_name, ext_name, f.write(f"{{\n{tab_char}\t\t\tbreak;\n{tab_char}\t\t}}\n") f.write(f"{tab_char}\tdefault: break;\n{tab_char}\t}}\n") -def write_verify_func_switch(f, dict_of_lists, profile, profile_name, ext_name): +def write_verify_func_switch(f, dict_of_lists, profile, profile_name, availability): """Generate function to check if a string is in a set of strings. Input is a file to write the code into, a dict where keys are length and the values are lists of strings of that length. And a suffix if any.""" if len(dict_of_lists) == 0: return - is_ext = ext_name is not None and len(ext_name) > 0 - is_promoted = check_promoted(profile.openxr_version_promoted) - f.write(f"\t// generated from: {profile_name}\n") # Example: pico neo 3 can be enabled by either enabling XR_BD_controller_interaction ext or using OpenXR 1.1+. # Disabling OXR_HAVE_BD_controller_interaction should NOT remove pico neo from OpenXR 1.1+ (it makes "exts->BD_controller_interaction" invalid C code). # Therefore separate code blocks for ext and version checks generated to avoid ifdef hell. - if is_ext: - f.write(f'#ifdef OXR_HAVE_{profile.extension_name}\n') - f.write(f"\tif (exts->{ext_name}) {{\n") - write_verify_switch_body(f, dict_of_lists, profile, profile_name, ext_name, '\t') - f.write("\t}\n") - f.write(f'#endif // OXR_HAVE_{profile.extension_name}\n') + feature_sets = sorted(availability.feature_sets, key=FeatureSet.as_tuple) + for feature_set in feature_sets: + requires_version = check_promoted(feature_set.required_version) + requires_extensions = bool(feature_set.required_extensions) - # The split into "is_promoted" and "not is_ext and not is_promoted" cases is not strictly necessary as we could generate the version check for both cases. - # For the "not is_ext and not is_promoted" case this would generate "if (openxr_version >= XR_MAKE_VERSION(0, 0, 0))", which we avoid doing here by this split. - if is_promoted: - f.write(f'\tif (openxr_version >= XR_MAKE_VERSION({profile.openxr_version_promoted["major"]}, {profile.openxr_version_promoted["minor"]}, 0)) {{\n') - write_verify_switch_body(f, dict_of_lists, profile, profile_name, ext_name, '\t') - f.write("\t}\n") + tab_char = '' + closing = [] - if not is_ext and not is_promoted: - write_verify_switch_body(f, dict_of_lists, profile, profile_name, ext_name, '') + if requires_version: + tab_char += '\t' + f.write(f'{tab_char}if (openxr_version >= XR_MAKE_VERSION({feature_set.required_version["major"]}, {feature_set.required_version["minor"]}, 0)) {{\n') + closing.append(f'{tab_char}}}\n') -def write_verify_func_body(f, profile, dict_name): + if requires_extensions: + tab_char += '\t' + exts = sorted(feature_set.required_extensions) + ext_defines = ' && '.join(f'defined(OXR_HAVE_{ext})' for ext in exts) + f.write(f'#if {ext_defines}\n') + f.write(f'{tab_char}if ('+' && '.join(f'exts->{ext}' for ext in exts)+') {\n') + closing.append(f'{tab_char}}}\n#endif // {ext_defines}\n') + + write_verify_switch_body(f, dict_of_lists, profile, profile_name, tab_char) + + for closer in reversed(closing): + f.write(closer) + +def write_verify_func_body(f, profile, dict_name, availability): if profile is None or dict_name is None or len(dict_name) == 0: return write_verify_func_switch(f, getattr( - profile, dict_name), profile, profile.name, profile.extension_name) + profile, dict_name), profile, profile.name, availability) if profile.parent_profiles is None: return for pp in sorted(profile.parent_profiles, key=attrgetter("name")): - write_verify_func_body(f, pp, dict_name) + write_verify_func_body(f, pp, dict_name, availability.intersection(pp.availability())) def write_verify_func(f, profile, dict_name, suffix): write_verify_func_begin( f, f"oxr_verify_{profile.validation_func_name}{suffix}") - write_verify_func_body(f, profile, dict_name) + write_verify_func_body(f, profile, dict_name, profile.availability()) write_verify_func_end(f)