zoukankan      html  css  js  c++  java
  • protoc-gen-lua支持嵌套类型

    #!/usr/bin/env python
    # -*- encoding:utf8 -*-
    # protoc-gen-lua
    # Google's Protocol Buffers project, ported to lua.
    # https://code.google.com/p/protoc-gen-lua/
    #
    # Copyright (c) 2010 , 林卓毅 (Zhuoyi Lin) netsnail@gmail.com
    # All rights reserved.
    #
    # Use, modification and distribution are subject to the "New BSD License"
    # as listed at <url: http://www.opensource.org/licenses/bsd-license.php >.
    
    import sys
    import os.path as path
    from cStringIO import StringIO
    
    import plugin_pb2
    import google.protobuf.descriptor_pb2 as descriptor_pb2
    
    # Module-level state. code_gen_file stores each generated Lua module in
    # _files (keyed by output file name) and main() returns them to protoc;
    # _packages and _message are not used by the code in this script.
    _packages, _files, _message = {}, {}, {}

    # Shorthand for the FieldDescriptorProto class, whose TYPE_* / LABEL_*
    # constants are used throughout the generator.
    FDP = plugin_pb2.descriptor_pb2.FieldDescriptorProto

    # The plugin protocol exchanges serialized protobuf over stdin/stdout.
    # On Windows those streams start in text mode, whose newline translation
    # would corrupt the binary payload, so both are switched to binary mode.
    if sys.platform == "win32":
        import msvcrt
        import os
        msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
        msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
    
    class CppType:
        """C++-side type categories, numbered 1 through 10.

        The numbers are written verbatim into each generated field
        descriptor's ``cpp_type`` (see code_gen_field via CPP_TYPE).
        """
        (CPPTYPE_INT32, CPPTYPE_INT64, CPPTYPE_UINT32, CPPTYPE_UINT64,
         CPPTYPE_DOUBLE, CPPTYPE_FLOAT, CPPTYPE_BOOL, CPPTYPE_ENUM,
         CPPTYPE_STRING, CPPTYPE_MESSAGE) = range(1, 11)
    
    # Maps each FieldDescriptorProto wire type to its CppType category; several
    # wire encodings (varint / zigzag / fixed) share one C++ representation.
    CPP_TYPE = dict.fromkeys(
        (FDP.TYPE_INT32, FDP.TYPE_SFIXED32, FDP.TYPE_SINT32),
        CppType.CPPTYPE_INT32)
    CPP_TYPE.update(dict.fromkeys(
        (FDP.TYPE_INT64, FDP.TYPE_SFIXED64, FDP.TYPE_SINT64),
        CppType.CPPTYPE_INT64))
    CPP_TYPE.update(dict.fromkeys(
        (FDP.TYPE_UINT32, FDP.TYPE_FIXED32),
        CppType.CPPTYPE_UINT32))
    CPP_TYPE.update(dict.fromkeys(
        (FDP.TYPE_UINT64, FDP.TYPE_FIXED64),
        CppType.CPPTYPE_UINT64))
    CPP_TYPE.update(dict.fromkeys(
        (FDP.TYPE_STRING, FDP.TYPE_BYTES),
        CppType.CPPTYPE_STRING))
    CPP_TYPE[FDP.TYPE_DOUBLE] = CppType.CPPTYPE_DOUBLE
    CPP_TYPE[FDP.TYPE_FLOAT] = CppType.CPPTYPE_FLOAT
    CPP_TYPE[FDP.TYPE_BOOL] = CppType.CPPTYPE_BOOL
    CPP_TYPE[FDP.TYPE_ENUM] = CppType.CPPTYPE_ENUM
    CPP_TYPE[FDP.TYPE_MESSAGE] = CppType.CPPTYPE_MESSAGE
    
    def printerr(*args):
        """Write ``args`` to stderr as one space-separated line, then flush.

        Each argument is formatted with ``%s``, so non-string values (ints,
        nodes, ...) are accepted instead of making ``str.join`` raise.
        """
        sys.stderr.write(" ".join("%s" % (arg,) for arg in args))
        sys.stderr.write("\n")
        sys.stderr.flush()
    
    class TreeNode(object):
        """One named node in the generator's tree of packages and types.

        The root has the empty name ``''``; package components, messages and
        enums hang below it. Each node remembers the proto ``filename``
        (without extension) and ``package`` (list of dotted components) that
        were current when it was created; ``Env.get_ref_name`` compares the
        filename to decide whether a reference crosses into another module.
        """

        def __init__(self, name, parent=None, filename=None, package=None):
            super(TreeNode, self).__init__()
            self.child = []
            self.parent = parent
            self.filename = filename
            self.package = package
            self.name = name
            if parent:
                # Link into the tree so lookups from the root can reach us.
                self.parent.add_child(self)

        def add_child(self, child):
            """Append ``child`` to this node's children (duplicates allowed)."""
            self.child.append(child)

        def find_child(self, child_names):
            """Return the descendant reached by following ``child_names``.

            ``child_names`` is a sequence with one name per level; an empty
            sequence returns ``self``. At each level the first child with a
            matching name is taken.

            Raises:
                KeyError: if a name along the path has no matching child.
                    KeyError derives from StandardError on Python 2, so
                    callers catching StandardError keep working, and unlike
                    StandardError it also exists on Python 3.
            """
            node = self
            for name in child_names:
                found = node.get_child(name)
                if found is None:
                    raise KeyError('.'.join(child_names))
                node = found
            return node

        def get_child(self, child_name):
            """Return the first direct child named ``child_name``, or None."""
            for i in self.child:
                if i.name == child_name:
                    return i
            return None

        def get_path(self, end=None):
            """Return the dotted names from just below ``end`` down to self.

            With ``end=None`` the walk includes the root, whose empty name
            yields a leading '.' (e.g. ``'.pkg.Msg'``).
            """
            pos = self
            out = []
            while pos and pos != end:
                out.append(pos.name)
                pos = pos.parent
            out.reverse()
            return '.'.join(out)

        def get_global_name(self):
            """Return the fully-qualified name, e.g. ``'.pkg.Outer.Inner'``."""
            return self.get_path()

        def get_local_name(self):
            """Return the name relative to this node's package.

            Walks up until a node named like the last package component (or
            the root) and returns the path below it, e.g. ``'Outer.Inner'``
            for ``.pkg.Outer.Inner``.
            """
            # NOTE(review): the match is by name only, so an enclosing message
            # named like the last package component would stop the walk early.
            pos = self
            while pos.parent:
                pos = pos.parent
                if self.package and pos.name == self.package[-1]:
                    break
            return self.get_path(pos)

        def __str__(self):
            return self.to_string(0)

        def __repr__(self):
            return str(self)

        def to_string(self, indent=0):
            """Return a multi-line debug dump of this subtree."""
            pad = ' ' * indent
            children = ','.join([i.to_string(indent + 4) for i in self.child])
            return pad + '<TreeNode ' + self.name + '(\n' + children + pad + ')>\n'
    
    class Env(object):
        """State shared by the code generators while walking a request.

        ``message_tree`` accumulates every package, message and enum seen so
        far (across all files of the request, so later files can resolve
        references into earlier ones); ``scope`` is the node currently being
        generated. The list attributes collect Lua source fragments for the
        current file and are reset by ``enter_file`` / ``exit_file``.
        """
        filename = None     # current proto file name, without extension
        package = None      # current package, split on '.'
        extend = None       # not used by the code in this script
        descriptor = None   # descriptor object declarations
        message = None      # message-class and top-level enum value lines
        context = None      # descriptor field assignments
        register = None     # "X.RegisterExtension(Y)" lines

        def __init__(self):
            self.message_tree = TreeNode('')
            self.scope = self.message_tree

        def get_global_name(self):
            """Return the fully-qualified name of the current scope."""
            return self.scope.get_global_name()

        def get_local_name(self):
            """Return the current scope's name relative to its package."""
            return self.scope.get_local_name()

        def get_ref_name(self, type_name):
            """Return the dotted Lua-side name for a referenced type.

            ``type_name`` is a reference as found in a field descriptor, e.g.
            ``'.pkg.Outer.Inner'``. Types already in the tree and owned by
            another file get a ``'<file>_pb.'`` prefix; everything else is
            returned package-relative (``'Outer.Inner'``).
            """
            try:
                node = self.lookup_name(type_name)
            except Exception:
                # Not in the tree yet, so assume it is declared later in the
                # current file. Exception (not a bare except) still covers the
                # lookup's failure modes: the error raised by
                # TreeNode.find_child, or AttributeError for a relative name
                # at the root scope.
                return self._strip_package(type_name)
            if node.filename != self.filename:
                return node.filename + '_pb.' + node.get_local_name()
            return node.get_local_name()

        def _strip_package(self, type_name):
            """Remove the current ".<package>." qualifier from ``type_name``.

            For the empty package only the leading '.' is removed (the old
            ``len(package) + 2`` slice also dropped the first letter there).
            Names outside the current package just lose their leading dots.
            """
            package = '.'.join(self.package or [])
            prefix = '.' + package + '.' if package else '.'
            if type_name.startswith(prefix):
                return type_name[len(prefix):]
            return type_name.lstrip('.')

        def lookup_name(self, name):
            """Resolve a dotted type name to its tree node.

            A leading '.' means fully qualified from the root; otherwise the
            search starts at the current scope's parent. Unknown names raise
            whatever ``TreeNode.find_child`` raises.
            """
            names = name.split('.')
            if names[0] == '':
                return self.message_tree.find_child(names[1:])
            else:
                return self.scope.parent.find_child(names)

        def enter_package(self, package):
            """Return the node for dotted ``package``, creating missing parts.

            An empty package maps to the root node.
            """
            if not package:
                return self.message_tree
            names = package.split('.')
            pos = self.message_tree
            for i, name in enumerate(names):
                new_pos = pos.get_child(name)
                if new_pos:
                    pos = new_pos
                else:
                    return self._build_nodes(pos, names[i:])
            return pos

        def enter_file(self, filename, package):
            """Start generating ``filename`` (no extension) in ``package``."""
            self.filename = filename
            self.package = package.split('.')
            self._init_field()
            self.scope = self.enter_package(package)

        def exit_file(self):
            """Reset per-file state after a file has been generated."""
            self._init_field()
            self.filename = None
            self.package = []
            # Go back to the root. Stepping to scope.parent only undid one
            # package level, and gave None when the package was empty.
            self.scope = self.message_tree

        def enter(self, message_name):
            """Open a child scope for a message/enum under the current scope."""
            self.scope = TreeNode(message_name, self.scope, self.filename,
                                  self.package)

        def exit(self):
            """Close the scope opened by the matching ``enter``."""
            self.scope = self.scope.parent

        def _init_field(self):
            """Start fresh per-file fragment lists."""
            self.descriptor = []
            self.context = []
            self.message = []
            self.register = []

        def _build_nodes(self, node, names):
            """Create a chain of nodes for ``names`` below ``node``; return the last."""
            parent = node
            for i in names:
                parent = TreeNode(i, parent, self.filename, self.package)
            return parent
    
    class Writer(object):
        """Append-only text buffer used to build generated Lua source.

        Each call writes the current indentation, then ``prefix`` (if one was
        given and is non-empty), then the data; no newline is added. Using the
        writer as a context manager indents later calls by four more spaces
        until the ``with`` block ends.
        """

        def __init__(self, prefix=None):
            self.__prefix = prefix
            self.__indent = ''
            self.io = StringIO()

        def getvalue(self):
            """Return everything written so far."""
            return self.io.getvalue()

        def __enter__(self):
            self.__indent = self.__indent + ' ' * 4
            return self

        def __exit__(self, exc_type, exc_value, exc_tb):
            # Drop one indentation level; returning None lets exceptions from
            # the with-block propagate.
            self.__indent = self.__indent[:-4]

        def __call__(self, data):
            out = self.io
            out.write(self.__indent)
            if self.__prefix:
                out.write(self.__prefix)
            out.write(data)
    
    # Lua literal used as a field's default when the .proto gives none and the
    # field is neither repeated nor has a type_name (see code_gen_field).
    # NOTE(review): protoc normally sets type_name on enum and message fields,
    # so those two entries are presumably never reached — confirm.
    DEFAULT_VALUE = dict.fromkeys(
        (FDP.TYPE_INT32, FDP.TYPE_INT64, FDP.TYPE_UINT32, FDP.TYPE_UINT64,
         FDP.TYPE_SINT32, FDP.TYPE_SINT64, FDP.TYPE_FIXED32, FDP.TYPE_FIXED64,
         FDP.TYPE_SFIXED32, FDP.TYPE_SFIXED64),
        '0')
    DEFAULT_VALUE.update({
        FDP.TYPE_DOUBLE: '0.0',
        FDP.TYPE_FLOAT: '0.0',
        FDP.TYPE_BOOL: 'false',
        FDP.TYPE_STRING: '""',
        FDP.TYPE_BYTES: '""',
        FDP.TYPE_ENUM: '1',
        FDP.TYPE_MESSAGE: 'nil',
    })
    
    def code_gen_enum_item(index, enum_value, env):
        """Emit the Lua EnumValueDescriptor for one enum value.

        Appends a ``local`` declaration to ``env.descriptor`` and the
        descriptor's field assignments to ``env.context``.

        Args:
            index: position of the value within its enum.
            enum_value: an ``EnumValueDescriptorProto``.
            env: the generator ``Env``; its scope is the enclosing enum.

        Returns:
            The Lua variable name, e.g. ``OUTER_COLOR_RED_ENUM``.
        """
        full_name = env.get_local_name() + '.' + enum_value.name
        obj_name = full_name.upper().replace('.', '_') + '_ENUM'
        env.descriptor.append(
            "local %s = protobuf.EnumValueDescriptor();\n" % obj_name
        )

        context = Writer(obj_name)
        context('.name = "%s"\n' % enum_value.name)
        context('.index = %d\n' % index)
        context('.number = %d\n' % enum_value.number)

        env.context.append(context.getvalue())
        return obj_name
    
    def code_gen_enum(enum_desc, env):
        """Emit the Lua EnumDescriptor for one enum type and all its values.

        The enum is entered into ``env``'s type tree (so fields can reference
        it) and then left again.

        Returns:
            The Lua variable name of the descriptor: the package-relative name
            upper-cased with '.' replaced by '_'.
        """
        env.enter(enum_desc.name)
        full_name = env.get_local_name()
        obj_name = full_name.upper().replace('.', '_')
        env.descriptor.append(
            "local %s = protobuf.EnumDescriptor();\n" % obj_name
        )

        context = Writer(obj_name)
        context('.name = "%s"\n' % enum_desc.name)
        context('.full_name = "%s"\n' % env.get_global_name())

        values = []
        for i, enum_value in enumerate(enum_desc.value):
            values.append(code_gen_enum_item(i, enum_value, env))
        context('.values = {%s}\n' % ','.join(values))

        env.context.append(context.getvalue())
        env.exit()
        return obj_name
    
    def _lua_string_escape(text):
        """Escape ``text`` for use inside a double-quoted Lua string literal."""
        return (text.replace('\\', '\\\\')
                    .replace('"', '\\"')
                    .replace('\n', '\\n')
                    .replace('\r', '\\r'))

    def _lua_type_ref(ref_name):
        """Turn an ``Env.get_ref_name`` result into a Lua descriptor reference.

        Descriptor variables are named by upper-casing the package-relative
        name and joining nesting levels with '_' (``Outer.Inner`` ->
        ``OUTER_INNER``). A ``<file>_pb.`` prefix for another file's type is
        kept as a table lookup (``OTHER_PB.OUTER_INNER``).
        """
        # NOTE(review): a same-file type whose local name itself contains
        # "_pb." would be split here; unlikely, but not handled.
        head, sep, tail = ref_name.partition('_pb.')
        if not sep:
            head, tail = '', ref_name
        return (head + sep).upper() + tail.upper().replace('.', '_')

    def code_gen_field(index, field_desc, env):
        """Emit the Lua FieldDescriptor for one field or extension.

        Args:
            index: position within the message's field (or extension) list.
            field_desc: a ``FieldDescriptorProto``.
            env: the generator ``Env``; its scope is the containing message.

        Appends the declaration to ``env.descriptor`` and the assignments to
        ``env.context``; for extensions, also appends a ``RegisterExtension``
        call to ``env.register``.

        Returns:
            The Lua variable name of the descriptor (``..._FIELD``).
        """
        full_name = env.get_local_name() + '.' + field_desc.name
        obj_name = full_name.upper().replace('.', '_') + '_FIELD'
        env.descriptor.append(
            "local %s = protobuf.FieldDescriptor();\n" % obj_name
        )

        context = Writer(obj_name)

        context('.name = "%s"\n' % field_desc.name)
        context('.full_name = "%s"\n' % (
            env.get_global_name() + '.' + field_desc.name))
        context('.number = %d\n' % field_desc.number)
        context('.index = %d\n' % index)
        context('.label = %d\n' % field_desc.label)

        if field_desc.HasField("default_value"):
            context('.has_default_value = true\n')
            value = field_desc.default_value
            if field_desc.type == FDP.TYPE_STRING:
                # protoc passes string defaults unescaped; escape characters
                # that would end or corrupt the Lua "..." literal.
                context('.default_value = "%s"\n' % _lua_string_escape(value))
            else:
                # NOTE(review): bytes defaults (C-escaped text) and enum
                # defaults (value names) are emitted unquoted here — confirm
                # they produce valid Lua.
                context('.default_value = %s\n' % value)
        else:
            context('.has_default_value = false\n')
            if field_desc.label == FDP.LABEL_REPEATED:
                default_value = "{}"
            elif field_desc.HasField('type_name'):
                default_value = "nil"
            else:
                default_value = DEFAULT_VALUE[field_desc.type]
            context('.default_value = %s\n' % default_value)

        if field_desc.HasField('type_name'):
            # Nested types must map to the declared OUTER_INNER variable;
            # a plain upper() gave OUTER.INNER, which indexes the OUTER
            # descriptor table and evaluates to nil.
            type_name = _lua_type_ref(env.get_ref_name(field_desc.type_name))
            if field_desc.type == FDP.TYPE_MESSAGE:
                context('.message_type = %s\n' % type_name)
            else:
                context('.enum_type = %s\n' % type_name)

        if field_desc.HasField('extendee'):
            # The extendee is the message *class* (Outer.Inner as assigned in
            # env.message), so the dotted local name is kept here.
            type_name = env.get_ref_name(field_desc.extendee)
            env.register.append(
                "%s.RegisterExtension(%s)\n" % (type_name, obj_name)
            )

        context('.type = %d\n' % field_desc.type)
        context('.cpp_type = %d\n\n' % CPP_TYPE[field_desc.type])
        env.context.append(context.getvalue())
        return obj_name
    
    def code_gen_message(message_descriptor, env, containing_type=None):
        """Emit the Lua Descriptor and Message class for one message type.

        Nested messages are generated first (recursively), then nested enums,
        fields and extensions; their code lands in ``env.descriptor`` and
        ``env.context``. The message class assignment goes to
        ``env.message``.

        Args:
            message_descriptor: a ``DescriptorProto``.
            env: the generator ``Env``.
            containing_type: Lua variable name of the enclosing message's
                descriptor, or None for a top-level message.

        Returns:
            The Lua variable name of the descriptor, e.g. ``OUTER_INNER`` for
            ``Outer.Inner``.
        """
        env.enter(message_descriptor.name)
        full_name = env.get_local_name()
        obj_name = full_name.upper().replace('.', '_')
        # Declared without ``local``, unlike the field and enum descriptors.
        env.descriptor.append(
            "%s = protobuf.Descriptor();\n" % obj_name
        )

        context = Writer(obj_name)
        context('.name = "%s"\n' % message_descriptor.name)
        context('.full_name = "%s"\n' % env.get_global_name())

        nested_types = []
        for msg_desc in message_descriptor.nested_type:
            msg_name = code_gen_message(msg_desc, env, obj_name)
            nested_types.append(msg_name)
        context('.nested_types = {%s}\n' % ', '.join(nested_types))

        enums = []
        for enum_desc in message_descriptor.enum_type:
            enums.append(code_gen_enum(enum_desc, env))
        context('.enum_types = {%s}\n' % ', '.join(enums))

        fields = []
        for i, field_desc in enumerate(message_descriptor.field):
            fields.append(code_gen_field(i, field_desc, env))

        context('.fields = {%s}\n' % ', '.join(fields))
        if len(message_descriptor.extension_range) > 0:
            context('.is_extendable = true\n')
        else:
            context('.is_extendable = false\n')

        extensions = []
        for i, field_desc in enumerate(message_descriptor.extension):
            extensions.append(code_gen_field(i, field_desc, env))
        context('.extensions = {%s}\n' % ', '.join(extensions))

        if containing_type:
            context('.containing_type = %s\n' % containing_type)

        # Nested classes become fields of their parent class table
        # (e.g. "Outer.Inner = protobuf.Message(OUTER_INNER)").
        env.message.append('%s = protobuf.Message(%s)\n' % (full_name,
                                                            obj_name))

        env.context.append(context.getvalue())
        env.exit()
        return obj_name
    
    def write_header(writer):
        """Emit the "do not edit" banner as the first line of a generated file.

        ``writer`` is any callable that accepts a string (normally a Writer).
        The banner is exactly one line; the old triple-quoted literal also
        carried stray indentation into the next output line.
        """
        writer('-- Generated By protoc-gen-lua Do not Edit\n')
    
    def code_gen_file(proto_file, env, is_gen):
        """Generate ``<name>_pb.lua`` for one ``FileDescriptorProto``.

        Every file of the request is walked, so its types enter
        ``env.message_tree`` and later files can reference them. Lua source
        is only stored in ``_files`` when ``is_gen`` is true.
        """
        filename = path.splitext(proto_file.name)[0]
        env.enter_file(filename, proto_file.package)

        includes = [path.splitext(f)[0] for f in proto_file.dependency]

        # NOTE(review): top-level extensions (proto_file.extension) are not
        # generated.

        for enum_desc in proto_file.enum_type:
            code_gen_enum(enum_desc, env)
            # Top-level enum values are also exposed as module constants.
            for enum_value in enum_desc.value:
                env.message.append('%s = %d\n' % (enum_value.name,
                                                  enum_value.number))

        for msg_desc in proto_file.message_type:
            code_gen_message(msg_desc, env)

        if is_gen:
            lua = Writer()
            write_header(lua)
            lua('local protobuf = require "protobuf"\n')
            for i in includes:
                lua('local %s_PB = require("%s_pb")\n' % (i.upper(), i))
            lua("module('%s_pb')\n" % env.filename)

            lua('\n\n')
            # Plain loops instead of map(): map() is for results, not side
            # effects, and on Python 3 it is lazy and would write nothing.
            for chunk in env.descriptor:
                lua(chunk)
            lua('\n')
            for chunk in env.context:
                lua(chunk)
            lua('\n')
            # Sorting puts "Outer = ..." before "Outer.Inner = ..." (' ' sorts
            # before '.'), so a parent class exists before nested classes are
            # attached to it.
            env.message.sort()
            for chunk in env.message:
                lua(chunk)
            lua('\n')
            for chunk in env.register:
                lua(chunk)

            _files[env.filename + '_pb.lua'] = lua.getvalue()
        env.exit_file()
    
    def main():
        """protoc plugin entry point.

        Reads a serialized CodeGeneratorRequest from stdin, runs the generator
        over every proto file in it, and writes a CodeGeneratorResponse with
        one entry per generated ``*_pb.lua`` file to stdout.
        """
        raw_request = sys.stdin.read()
        request = plugin_pb2.CodeGeneratorRequest()
        request.ParseFromString(raw_request)

        env = Env()
        wanted = request.file_to_generate
        for proto_file in request.proto_file:
            # Every file is walked so its types are known; only the
            # requested ones are written out.
            code_gen_file(proto_file, env, proto_file.name in wanted)

        response = plugin_pb2.CodeGeneratorResponse()
        for out_name, out_content in _files.items():
            entry = response.file.add()
            entry.name = out_name
            entry.content = out_content

        sys.stdout.write(response.SerializeToString())

    if __name__ == "__main__":
        main()
    

  将 protoc-gen-lua 文件的内容替换为以上代码即可。

  在 Lua 中使用时,复合字段(即类型为另一个消息的字段)已经有值,可以直接访问:

      local person = person_pb.Person()

      person.company.name = "xxx" --其中company为复合字段,也就是另一个类型,比如 company_pb.Company()

      直接用就可以了

  • 相关阅读:
    poj3278 Catch That Cow
    poj2251 Dungeon Master
    poj1321 棋盘问题
    poj3083 Children of the Candy Corn
    jvm基础知识—垃圾回收机制
    jvm基础知识1
    java面试基础必备
    java socket通信总结 bio nio aio的区别和总结
    java socket AIO 通信
    java socket Blocking 阻塞IO socket通信四
  • 原文地址:https://www.cnblogs.com/wanghe/p/4995579.html
Copyright © 2011-2022 走看看