zoukankan      html  css  js  c++  java
  • protoc-gen-lua支持嵌套类型

    #!/usr/bin/env python
    # -*- encoding:utf8 -*-
    # protoc-gen-lua
    # Google's Protocol Buffers project, ported to lua.
    # https://code.google.com/p/protoc-gen-lua/
    #
    # Copyright (c) 2010 , 林卓毅 (Zhuoyi Lin) netsnail@gmail.com
    # All rights reserved.
    #
    # Use, modification and distribution are subject to the "New BSD License"
    # as listed at <url: http://www.opensource.org/licenses/bsd-license.php >.
    
    import sys
    import os.path as path
    from cStringIO import StringIO
    
    import plugin_pb2
    import google.protobuf.descriptor_pb2 as descriptor_pb2
    
    # Module-level registries. _files maps each generated output file name
    # (e.g. "foo_pb.lua") to its Lua source; _packages and _message are not
    # referenced by the code in this file.
    _packages = {}
    _files = {}
    _message = {}

    # Short alias for the descriptor class whose TYPE_* / LABEL_* constants
    # are used throughout the generator.
    FDP = plugin_pb2.descriptor_pb2.FieldDescriptorProto

    # protoc talks to the plugin with binary protobuf messages on stdin/stdout,
    # so on Windows both streams are switched out of text mode.
    if sys.platform == "win32":
        import msvcrt
        import os
        msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
        msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
    
    class CppType:
        """Numeric C++ type categories written as ``.cpp_type`` of a field.

        The values run 1..10 in declaration order (presumably mirroring the
        CPPTYPE_* numbering of protobuf's FieldDescriptor -- confirm against
        the Lua runtime if it changes).
        """
        (CPPTYPE_INT32,
         CPPTYPE_INT64,
         CPPTYPE_UINT32,
         CPPTYPE_UINT64,
         CPPTYPE_DOUBLE,
         CPPTYPE_FLOAT,
         CPPTYPE_BOOL,
         CPPTYPE_ENUM,
         CPPTYPE_STRING,
         CPPTYPE_MESSAGE) = range(1, 11)
    
    # Maps every FieldDescriptorProto.TYPE_* value to the CppType category that
    # is emitted as ``.cpp_type`` in generated field descriptors. Written as
    # (category, wire types) groups so each category appears exactly once.
    CPP_TYPE = dict(
        (fdp_type, cpp_type)
        for cpp_type, fdp_types in (
            (CppType.CPPTYPE_DOUBLE,  (FDP.TYPE_DOUBLE,)),
            (CppType.CPPTYPE_FLOAT,   (FDP.TYPE_FLOAT,)),
            (CppType.CPPTYPE_INT32,   (FDP.TYPE_INT32,
                                       FDP.TYPE_SFIXED32,
                                       FDP.TYPE_SINT32)),
            (CppType.CPPTYPE_INT64,   (FDP.TYPE_INT64,
                                       FDP.TYPE_SFIXED64,
                                       FDP.TYPE_SINT64)),
            (CppType.CPPTYPE_UINT32,  (FDP.TYPE_UINT32, FDP.TYPE_FIXED32)),
            (CppType.CPPTYPE_UINT64,  (FDP.TYPE_UINT64, FDP.TYPE_FIXED64)),
            (CppType.CPPTYPE_BOOL,    (FDP.TYPE_BOOL,)),
            (CppType.CPPTYPE_ENUM,    (FDP.TYPE_ENUM,)),
            (CppType.CPPTYPE_STRING,  (FDP.TYPE_STRING, FDP.TYPE_BYTES)),
            (CppType.CPPTYPE_MESSAGE, (FDP.TYPE_MESSAGE,)),
        )
        for fdp_type in fdp_types
    )
    
    def printerr(*args):
        """Write *args* to stderr, space-separated and newline-terminated.

        Arguments must be strings (they are passed to ``str.join``). The stream
        is flushed immediately, so diagnostics show up even if protoc kills
        the plugin.
        """
        sys.stderr.write(" ".join(args))
        sys.stderr.write("\n")
        sys.stderr.flush()
    
    class TreeNode(object):
        """One segment in the tree of dotted protobuf names.

        Nodes represent packages, messages and enums. Each node stores its own
        name segment, its parent, its children, and the filename / split
        package it was created under (both None for the root).
        """
        def __init__(self, name, parent=None, filename=None, package=None):
            super(TreeNode, self).__init__()
            self.child = []
            self.parent = parent
            self.filename = filename
            self.package = package
            if parent:
                # Register with the parent so lookups from the root find us.
                self.parent.add_child(self)
            self.name = name

        def add_child(self, child):
            """Append *child* to this node's children."""
            self.child.append(child)

        def find_child(self, child_names):
            """Descend through children matching *child_names* in order.

            :param child_names: sequence of name segments, e.g. ['pkg', 'Msg'];
                an empty sequence returns this node.
            :returns: the node reached after the last segment.
            :raises KeyError: if a segment has no matching child (the message
                is the dotted path up to and including the missing segment).
            """
            node = self
            for depth, child_name in enumerate(child_names):
                next_node = node.get_child(child_name)
                if next_node is None:
                    raise KeyError('.'.join(child_names[:depth + 1]))
                node = next_node
            return node

        def get_child(self, child_name):
            """Return the first direct child named *child_name*, or None."""
            for i in self.child:
                if i.name == child_name:
                    return i
            return None

        def get_path(self, end = None):
            """Return the dotted path from *end* (exclusive) down to this node.

            With the default ``end=None`` the walk goes up through the root,
            whose empty name yields a leading '.' (e.g. '.pkg.Msg').
            """
            pos = self
            out = []
            while pos and pos != end:
                out.append(pos.name)
                pos = pos.parent
            out.reverse()
            return '.'.join(out)

        def get_global_name(self):
            """Return the fully qualified dotted name, starting with '.'."""
            return self.get_path()

        def get_local_name(self):
            """Return the dotted name relative to this node's package.

            Walks up to the first ancestor whose name equals the last package
            segment (or to the root when none matches / there is no package)
            and returns the path below it, e.g. 'Msg.Inner'.
            """
            pos = self
            while pos.parent:
                pos = pos.parent
                if self.package and pos.name == self.package[-1]:
                    break
            return self.get_path(pos)

        def __str__(self):
            return self.to_string(0)

        def __repr__(self):
            return str(self)

        def to_string(self, indent = 0):
            """Render this subtree for debugging, children indented 4 more."""
            pad = ' ' * indent
            children = ','.join([i.to_string(indent + 4) for i in self.child])
            return pad + '<TreeNode ' + self.name + '(\n' + children + pad + ')>\n'
    
    class Env(object):
        """Code-generation state shared across every file protoc sends.

        ``message_tree`` is the global tree of package/message/enum names and
        ``scope`` is the node currently being generated. The per-file lists
        ``descriptor``, ``context``, ``message`` and ``register`` collect Lua
        source fragments that code_gen_file later writes out.
        """
        filename = None     # current file name without extension
        package = None      # current package split on '.'
        extend = None       # not referenced in this file
        descriptor = None   # "local X = protobuf.*Descriptor()" lines
        message = None      # "Name = protobuf.Message(X)" style lines
        context = None      # descriptor attribute assignments
        register = None     # RegisterExtension calls
        def __init__(self):
            self.message_tree = TreeNode('')
            self.scope = self.message_tree

        def get_global_name(self):
            """Return the fully qualified name of the current scope."""
            return self.scope.get_global_name()

        def get_local_name(self):
            """Return the package-relative name of the current scope."""
            return self.scope.get_local_name()

        def get_ref_name(self, type_name):
            """Translate a protoc type reference into the Lua name to emit.

            :param type_name: type name from a FieldDescriptorProto, normally
                fully qualified with a leading dot ('.pkg.Msg.Inner').
            :returns: the local name ('Msg.Inner') for a type declared in the
                current file, or '<file>_pb.<local name>' for another file's.
            """
            try:
                node = self.lookup_name(type_name)
            except Exception:
                # The type is not in the name tree yet (presumably a forward
                # reference to a type declared later), so it must belong to
                # this file: strip the leading '.' and our package prefix.
                package = '.'.join(self.package)
                prefix = '.' + package + '.' if package else '.'
                if type_name.startswith(prefix):
                    return type_name[len(prefix):]
                return type_name.lstrip('.')
            if node.filename != self.filename:
                return node.filename + '_pb.' + node.get_local_name()
            return node.get_local_name()

        def lookup_name(self, name):
            """Resolve dotted *name* to a TreeNode.

            A leading '.' means absolute (searched from the root); otherwise
            the name is searched relative to the parent of the current scope.
            Lookup failures propagate from TreeNode.find_child.
            """
            names = name.split('.')
            if names[0] == '':
                return self.message_tree.find_child(names[1:])
            else:
                return self.scope.parent.find_child(names)

        def enter_package(self, package):
            """Return the node for dotted *package*, creating missing segments.

            An empty package maps to the root node.
            """
            if not package:
                return self.message_tree
            names = package.split('.')
            pos = self.message_tree
            for i, name in enumerate(names):
                new_pos = pos.get_child(name)
                if new_pos:
                    pos = new_pos
                else:
                    return self._build_nodes(pos, names[i:])
            return pos

        def enter_file(self, filename, package):
            """Begin a file: record name/package, reset buffers, enter package."""
            self.filename = filename
            self.package = package.split('.')
            self._init_field()
            self.scope = self.enter_package(package)

        def exit_file(self):
            """Finish a file: reset buffers and move scope to the package's parent."""
            self._init_field()
            self.filename = None
            self.package = []
            self.scope = self.scope.parent

        def enter(self, message_name):
            """Create a child scope for a message/enum and make it current."""
            self.scope = TreeNode(message_name, self.scope, self.filename,
                                  self.package)

        def exit(self):
            """Return to the parent scope."""
            self.scope = self.scope.parent

        def _init_field(self):
            """Reset the per-file output buffers to empty lists."""
            self.descriptor = []
            self.context = []
            self.message = []
            self.register = []

        def _build_nodes(self, node, names):
            """Create a chain of nodes for *names* under *node*; return the last."""
            parent = node
            for i in names:
                parent = TreeNode(i, parent, self.filename, self.package)
            return parent
    
    class Writer(object):
        """In-memory text builder for generated Lua source.

        Each call writes the current indent, then the optional *prefix*, then
        the data verbatim (no newline is appended). Using the writer as a
        context manager indents everything written inside the ``with`` block
        by four more spaces.
        """
        def __init__(self, prefix=None):
            self.__prefix = prefix
            self.__indent = ''
            self.io = StringIO()

        def getvalue(self):
            """Return everything written so far."""
            return self.io.getvalue()

        def __enter__(self):
            self.__indent = self.__indent + ' ' * 4
            return self

        def __exit__(self, type, value, trackback):
            self.__indent = self.__indent[:-4]

        def __call__(self, data):
            pieces = [self.__indent]
            if self.__prefix:
                pieces.append(self.__prefix)
            pieces.append(data)
            for piece in pieces:
                self.io.write(piece)
    
    # Lua literal emitted as ``.default_value`` for a field that declares no
    # explicit default, keyed by FieldDescriptorProto.TYPE_*. Written as
    # (literal, wire types) groups so each literal appears exactly once.
    DEFAULT_VALUE = dict(
        (fdp_type, lua_literal)
        for lua_literal, fdp_types in (
            ('0.0',   (FDP.TYPE_DOUBLE, FDP.TYPE_FLOAT)),
            ('0',     (FDP.TYPE_INT64, FDP.TYPE_UINT64, FDP.TYPE_INT32,
                       FDP.TYPE_FIXED64, FDP.TYPE_FIXED32, FDP.TYPE_UINT32,
                       FDP.TYPE_SFIXED32, FDP.TYPE_SFIXED64,
                       FDP.TYPE_SINT32, FDP.TYPE_SINT64)),
            ('false', (FDP.TYPE_BOOL,)),
            ('""',    (FDP.TYPE_STRING, FDP.TYPE_BYTES)),
            ('nil',   (FDP.TYPE_MESSAGE,)),
            ('1',     (FDP.TYPE_ENUM,)),
        )
        for fdp_type in fdp_types
    )
    
    def code_gen_enum_item(index, enum_value, env):
        """Emit the Lua EnumValueDescriptor for one enum value.

        :param index: position of the value within its enum.
        :param enum_value: an EnumValueDescriptorProto.
        :param env: the shared Env, whose current scope is the enclosing enum.
        :returns: the Lua variable name of the descriptor, e.g. 'COLOR_RED_ENUM'.

        The declaration goes to ``env.descriptor`` and the attribute
        assignments (prefixed with the variable name) go to ``env.context``.
        """
        full_name = env.get_local_name() + '.' + enum_value.name
        obj_name = full_name.upper().replace('.', '_') + '_ENUM'
        env.descriptor.append(
            "local %s = protobuf.EnumValueDescriptor();\n" % obj_name
        )

        context = Writer(obj_name)
        context('.name = "%s"\n' % enum_value.name)
        context('.index = %d\n' % index)
        context('.number = %d\n' % enum_value.number)

        env.context.append(context.getvalue())
        return obj_name
    
    def code_gen_enum(enum_desc, env):
        """Emit the Lua EnumDescriptor for an enum and all of its values.

        :param enum_desc: an EnumDescriptorProto.
        :param env: the shared Env; a scope for the enum is pushed while its
            values are generated and popped again before returning.
        :returns: the Lua variable name of the enum descriptor.
        """
        env.enter(enum_desc.name)
        full_name = env.get_local_name()
        obj_name = full_name.upper().replace('.', '_')
        env.descriptor.append(
            "local %s = protobuf.EnumDescriptor();\n" % obj_name
        )

        context = Writer(obj_name)
        context('.name = "%s"\n' % enum_desc.name)
        context('.full_name = "%s"\n' % env.get_global_name())

        values = []
        for i, enum_value in enumerate(enum_desc.value):
            values.append(code_gen_enum_item(i, enum_value, env))
        context('.values = {%s}\n' % ','.join(values))

        env.context.append(context.getvalue())
        env.exit()
        return obj_name
    
    def code_gen_field(index, field_desc, env):
        """Emit the Lua FieldDescriptor for one field or extension.

        :param index: position of the field in its containing list.
        :param field_desc: a FieldDescriptorProto.
        :param env: the shared Env; the field is named relative to the
            current scope (its containing message).
        :returns: the Lua variable name of the descriptor, e.g. 'MSG_ID_FIELD'.

        For extensions (fields with ``extendee``) a RegisterExtension call is
        also queued in ``env.register``.
        """
        full_name = env.get_local_name() + '.' + field_desc.name
        obj_name = full_name.upper().replace('.', '_') + '_FIELD'
        env.descriptor.append(
            "local %s = protobuf.FieldDescriptor();\n" % obj_name
        )

        context = Writer(obj_name)

        context('.name = "%s"\n' % field_desc.name)
        context('.full_name = "%s"\n' % (
            env.get_global_name() + '.' + field_desc.name))
        context('.number = %d\n' % field_desc.number)
        context('.index = %d\n' % index)
        context('.label = %d\n' % field_desc.label)

        if field_desc.HasField("default_value"):
            context('.has_default_value = true\n')
            value = field_desc.default_value
            if field_desc.type == FDP.TYPE_STRING:
                # Escape characters that would end or corrupt a double-quoted
                # Lua string literal; plain values are emitted unchanged.
                value = (value.replace('\\', '\\\\')
                              .replace('"', '\\"')
                              .replace('\n', '\\n')
                              .replace('\r', '\\r'))
                context('.default_value = "%s"\n' % value)
            else:
                # NOTE(review): other defaults are emitted verbatim as Lua
                # expressions; bytes/enum defaults may need translation.
                context('.default_value = %s\n' % value)
        else:
            context('.has_default_value = false\n')
            if field_desc.label == FDP.LABEL_REPEATED:
                default_value = "{}"
            elif field_desc.HasField('type_name'):
                # Message and enum fields have no literal default.
                default_value = "nil"
            else:
                default_value = DEFAULT_VALUE[field_desc.type]
            context('.default_value = %s\n' % default_value)

        if field_desc.HasField('type_name'):
            type_name = env.get_ref_name(field_desc.type_name).upper()
            if field_desc.type == FDP.TYPE_MESSAGE:
                context('.message_type = %s\n' % type_name)
            else:
                context('.enum_type = %s\n' % type_name)

        if field_desc.HasField('extendee'):
            type_name = env.get_ref_name(field_desc.extendee)
            env.register.append(
                "%s.RegisterExtension(%s)\n" % (type_name, obj_name)
            )

        context('.type = %d\n' % field_desc.type)
        context('.cpp_type = %d\n\n' % CPP_TYPE[field_desc.type])
        env.context.append(context.getvalue())
        return obj_name
    
    def code_gen_message(message_descriptor, env, containing_type = None):
        """Emit the Lua Descriptor for a message, recursing into nested types.

        :param message_descriptor: a DescriptorProto.
        :param env: the shared Env; a scope for the message is pushed for the
            duration of the call.
        :param containing_type: Lua variable name of the enclosing message's
            descriptor when this message is nested, else None.
        :returns: the Lua variable name of this message's descriptor.

        Besides the descriptor, a ``<LocalName> = protobuf.Message(<DESC>)``
        line is queued in ``env.message`` (e.g. 'Outer.Inner = ...' for a
        nested message).
        """
        env.enter(message_descriptor.name)
        full_name = env.get_local_name()
        obj_name = full_name.upper().replace('.', '_')
        env.descriptor.append(
            "%s = protobuf.Descriptor();\n" % obj_name
        )

        context = Writer(obj_name)
        context('.name = "%s"\n' % message_descriptor.name)
        context('.full_name = "%s"\n' % env.get_global_name())

        nested_types = []
        for msg_desc in message_descriptor.nested_type:
            msg_name = code_gen_message(msg_desc, env, obj_name)
            nested_types.append(msg_name)
        context('.nested_types = {%s}\n' % ', '.join(nested_types))

        enums = []
        for enum_desc in message_descriptor.enum_type:
            enums.append(code_gen_enum(enum_desc, env))
        context('.enum_types = {%s}\n' % ', '.join(enums))

        fields = []
        for i, field_desc in enumerate(message_descriptor.field):
            fields.append(code_gen_field(i, field_desc, env))

        context('.fields = {%s}\n' % ', '.join(fields))
        if len(message_descriptor.extension_range) > 0:
            context('.is_extendable = true\n')
        else:
            context('.is_extendable = false\n')

        extensions = []
        for i, field_desc in enumerate(message_descriptor.extension):
            extensions.append(code_gen_field(i, field_desc, env))
        context('.extensions = {%s}\n' % ', '.join(extensions))

        if containing_type:
            context('.containing_type = %s\n' % containing_type)

        env.message.append('%s = protobuf.Message(%s)\n' % (full_name,
                                                            obj_name))

        env.context.append(context.getvalue())
        env.exit()
        return obj_name
    
    def write_header(writer):
        """Write the "generated file, do not edit" banner through *writer*."""
        banner = '-- Generated By protoc-gen-lua Do not Edit\n'
        writer(banner)
    
    def code_gen_file(proto_file, env, is_gen):
        """Walk one FileDescriptorProto and optionally render its Lua module.

        Every file is walked, so its types are registered in env's name tree
        for later files to reference. Lua output is only produced when
        *is_gen* is true; it is stored in the module-level ``_files`` dict
        under '<name>_pb.lua'.
        """
        filename = path.splitext(proto_file.name)[0]
        env.enter_file(filename, proto_file.package)

        includes = [path.splitext(f)[0] for f in proto_file.dependency]

        # NOTE: file-level extensions (proto_file.extension) are not generated.

        for enum_desc in proto_file.enum_type:
            code_gen_enum(enum_desc, env)
            # Also expose each top-level enum value as a module constant.
            for enum_value in enum_desc.value:
                env.message.append('%s = %d\n' % (enum_value.name,
                                                   enum_value.number))

        for msg_desc in proto_file.message_type:
            code_gen_message(msg_desc, env)

        if is_gen:
            lua = Writer()
            write_header(lua)
            lua('local protobuf = require "protobuf"\n')
            for inc in includes:
                lua('local %s_PB = require("%s_pb")\n' % (inc.upper(), inc))
            lua("module('%s_pb')\n" % env.filename)

            lua('\n\n')
            # Explicit loops: these writes are pure side effects, and map()
            # is lazy on Python 3, so it would silently write nothing.
            for line in env.descriptor:
                lua(line)
            lua('\n')
            for chunk in env.context:
                lua(chunk)
            lua('\n')
            # Sorting places "Outer = ..." before "Outer.Inner = ..." (' ' sorts
            # before '.'), so a nested message is assigned after its parent.
            env.message.sort()
            for line in env.message:
                lua(line)
            lua('\n')
            for line in env.register:
                lua(line)

            _files[env.filename + '_pb.lua'] = lua.getvalue()
        env.exit_file()
    
    def main():
        """protoc plugin entry point.

        Reads a serialized CodeGeneratorRequest from stdin, walks every
        proto file in order (generating Lua only for those listed in
        ``file_to_generate``) and writes a serialized CodeGeneratorResponse
        containing the generated files to stdout.
        """
        request_bytes = sys.stdin.read()
        request = plugin_pb2.CodeGeneratorRequest()
        request.ParseFromString(request_bytes)

        env = Env()
        targets = request.file_to_generate
        for proto_file in request.proto_file:
            code_gen_file(proto_file, env, proto_file.name in targets)

        response = plugin_pb2.CodeGeneratorResponse()
        for out_name, out_content in _files.items():
            out_file = response.file.add()
            out_file.name = out_name
            out_file.content = out_content

        sys.stdout.write(response.SerializeToString())

    if __name__ == "__main__":
        main()
    

  将 protoc-gen-lua 文件的内容替换为以上代码即可。

  在 Lua 中使用时，复合字段（即类型为另一个消息的字段）已经有值，直接访问即可：

      local person = person_pb.Person()

      person.company.name = "xxx" --其中company为复合字段,也就是另一个类型,比如 company_pb.Company()

  直接使用即可。

  • 相关阅读:
    jquery插件:web2.0分格的分页脚,可用于ajax无刷新分页
    Application共享数据
    WebClient类
    HttpResponse类
    IEqualityComparer<T>接口
    物理数据库设计 理解浮点数
    Server对象,HttpServerUtility类,获取服务器信息
    Linq to OBJECT之非延时标准查询操作符
    IComparer<T> 接口Linq比较接口
    会话状态Session
  • 原文地址:https://www.cnblogs.com/wanghe/p/4995579.html
Copyright © 2011-2022 走看看