When a union's discriminator is a pre-defined enum type, check that every branch name of the union is a valid value of that enum, and that every enum value is covered by a branch of the union.
Signed-off-by: Wenchao Xia <xiaw...@linux.vnet.ibm.com>
Reviewed-by: Eric Blake <ebl...@redhat.com>
---
 scripts/qapi-visit.py |   17 +++++++++++++++++
 scripts/qapi.py       |   31 +++++++++++++++++++++++++++++++
 2 files changed, 48 insertions(+), 0 deletions(-)

diff --git a/scripts/qapi-visit.py b/scripts/qapi-visit.py
index 65f1a54..c0efb5f 100644
--- a/scripts/qapi-visit.py
+++ b/scripts/qapi-visit.py
@@ -255,6 +255,23 @@ def generate_visit_union(expr):
         assert not base
         return generate_visit_anon_union(name, members)
 
+    # If discriminator is specified and it is a pre-defined enum in schema,
+    # check that branches and enum values match one-to-one
+    enum_define = discriminator_find_enum_define(expr)
+    if enum_define:
+        for key in members:
+            if key not in enum_define["enum_values"]:
+                sys.stderr.write("Discriminator value '%s' is not found in "
+                                 "enum '%s'\n" %
+                                 (key, enum_define["enum_name"]))
+                sys.exit(1)
+        for key in enum_define["enum_values"]:
+            if key not in members:
+                sys.stderr.write("Enum value '%s' is not covered by a branch "
+                                 "of union '%s'\n" %
+                                 (key, name))
+                sys.exit(1)
+
     ret = generate_visit_enum('%sKind' % name, members.keys())
 
     if base:
diff --git a/scripts/qapi.py b/scripts/qapi.py
index cf34768..0a3ab80 100644
--- a/scripts/qapi.py
+++ b/scripts/qapi.py
@@ -385,3 +385,34 @@ def guardend(name):
 
 ''',
                  name=guardname(name))
+
+# Return the 'data' of the struct named @base, or None if there is none
+def find_base_fields(base):
+    base_struct_define = find_struct(base)
+    if not base_struct_define:
+        return None
+    return base_struct_define.get('data')
+
+# Return the discriminator enum define, if discriminator is specified in
+# @expr and it is a pre-defined enum type
+def discriminator_find_enum_define(expr):
+    discriminator = expr.get('discriminator')
+    base = expr.get('base')
+
+    # Only support discriminator when base present
+    if not (discriminator and base):
+        return None
+
+    base_fields = find_base_fields(base)
+
+    if base_fields is None:
+        raise StandardError("Base '%s' is not a valid type"
+                            % base)
+
+    discriminator_type = base_fields.get(discriminator)
+
+    if not discriminator_type:
+        raise StandardError("Discriminator '%s' is not a member of base '%s'"
+                            % (discriminator, base))
+
+    return find_enum(discriminator_type)
-- 
1.7.1