#! /usr/bin/python

import sys
import os.path
import re

# Macro name -> replacement text.  getToken() substitutes a macro's text for
# its name when the name appears outside a '#' directive.  The dictionary is
# not populated anywhere in the visible source.
macros = {}

# Most recent token produced by getToken(); reset to None at end of input.
token = None
# Unconsumed remainder of the current input line.
line = ""
# Regular expression source for a C-style identifier.
idRe = "[a-zA-Z_][a-zA-Z_0-9]*"
# A token is an identifier (optionally prefixed by '#'), a run of decimal
# digits, or failing those any single character other than a newline.
tokenRe = "#?" + idRe + "|[0-9]+|."
# True while getToken() is skipping the inside of a /* ... */ comment.
inComment = False
# True from a token that starts with '#' until the rest of its line has been
# consumed, at which point getToken() emits the synthetic '$' token.
inDirective = False

def getLine():
    """Read the next raw line of inputFile into the global ``line``.

    Bumps the global ``lineNumber`` for every read.  An empty string from
    readline() means end of file, which is reported through fatal().
    """
    global line, lineNumber
    line = inputFile.readline()
    lineNumber = lineNumber + 1
    if not line:
        fatal("unexpected end of input")

def getToken():
    """Advance to the next token and store it in the global ``token``.

    Leading whitespace and /* ... */ comments are skipped.  A token that
    starts with '#' switches on directive mode; when the rest of that line
    has been consumed a synthetic '$' token is produced to mark the end of
    the directive.  Outside directive mode, a token that is a key of
    ``macros`` is replaced by its expansion, which is then re-scanned.

    Returns True when a token was produced.  At end of input, sets ``token``
    to None and returns False; if ``token`` was already None at that point,
    fatal() is called instead.
    """
    global token
    global line
    global inComment
    global inDirective
    while True:
        line = line.lstrip()
        if line != "":
            # This test comes before the inComment test, so a "/*" at the
            # start of the remaining text is consumed the same way whether
            # or not a comment is already open.
            if line.startswith("/*"):
                inComment = True
                line = line[2:]
            elif inComment:
                commentEnd = line.find("*/")
                if commentEnd < 0:
                    # Comment continues on a later line: drop the rest.
                    line = ""
                else:
                    inComment = False
                    line = line[commentEnd + 2:]
            else:
                # line is non-empty and starts with a non-whitespace
                # character here, so the trailing '.' alternative of tokenRe
                # always yields a match.
                match = re.match(tokenRe, line)
                token = match.group(0)
                line = line[len(token):]
                if token.startswith('#'):
                    inDirective = True
                elif token in macros and not inDirective:
                    # Push the expansion back in front of the remaining
                    # text and scan again from there.
                    line = macros[token] + line
                    continue
                return True
        elif inDirective:
            # The directive's line is exhausted: emit the end marker.
            token = "$"
            inDirective = False
            return True
        else:
            global lineNumber
            line = inputFile.readline()
            lineNumber += 1
            # Join backslash-continued lines into one logical line.
            while line.endswith("\\\n"):
                line = line[:-2] + inputFile.readline()
                lineNumber += 1
            if line == "":
                # readline() returned nothing: end of file.
                if token == None:
                    fatal("unexpected end of input")
                token = None
                return False

# Number of problems reported through error() so far.
n_errors = 0


def error(msg):
    """Write *msg* to stderr as "file:line: msg" and count it in n_errors.

    The location comes from the globals ``fileName`` and ``lineNumber``.
    """
    global n_errors
    report = "%s:%d: %s\n" % (fileName, lineNumber, msg)
    sys.stderr.write(report)
    n_errors = n_errors + 1

def fatal(msg):
    """Report *msg* through error(), then terminate with exit status 1."""
    error(msg)
    # Same effect as sys.exit(1), which raises SystemExit(1).
    raise SystemExit(1)

def skipDirective():
    """Consume tokens up to and including the next '$' token.

    At least one token is always consumed before the check.
    """
    while True:
        getToken()
        if token == '$':
            break

def isId(s):
    """Return True when *s* matches idRe from its start through a final '$'."""
    anchored = idRe + "$"
    found = re.match(anchored, s)
    return found is not None

def forceId():
    """Call fatal() unless the current token passes isId()."""
    if isId(token):
        return
    fatal("identifier expected")

def forceInteger():
    """Call fatal() unless the current token is a run of decimal digits."""
    if re.match('[0-9]+$', token):
        return
    fatal("integer expected")

def match(t):
    """Consume the current token if it equals *t*.

    Returns True (after advancing with getToken()) on a match; otherwise
    returns False and leaves the current token untouched.
    """
    if token != t:
        return False
    getToken()
    return True

def forceMatch(t):
    """Consume the token *t*, or call fatal() if the current token differs."""
    consumed = match(t)
    if not consumed:
        fatal("%s expected" % t)

def parseTaggedName():
    """Parse "struct NAME" or "union NAME" and return it as one string.

    The current token must be 'struct' or 'union' on entry.  Both the
    keyword and the identifier that follows it are consumed; a missing
    identifier is reported through forceId().
    """
    assert token in ('struct', 'union')
    keyword = token
    getToken()
    forceId()
    tagged = "%s %s" % (keyword, token)
    getToken()
    return tagged

def print_enum(tag, constants, storage_class):
    """Print a C function that maps enum constants to their names.

    The generated function is named ``<tag>_to_string``, takes a
    ``uint16_t`` and returns the constant's name as a string literal, or
    NULL when no case matches.

    tag            prefix for the generated function's name.
    constants      iterable of constant names; one ``case`` is emitted for
                   each, in iteration order.
    storage_class  text placed directly in front of ``const char *``; it is
                   not followed by an added space, so include one if needed.
    """
    # A parenthesized single-argument print behaves identically as a Python 2
    # statement and as a Python 3 function call.
    print("""
%(storage_class)sconst char *
%(tag)s_to_string(uint16_t value)
{
    switch (value) {\
""" % {"tag": tag,
       "storage_class": storage_class})
    for constant in constants:
        print("    case %s: return \"%s\";" % (constant, constant))
    print("""\
    }
    return NULL;
}\
""")

def usage():
    """Print the help text to stdout and exit with status 0.

    The program name shown is the basename of sys.argv[0].
    """
    argv0 = os.path.basename(sys.argv[0])
    # A parenthesized single-argument print behaves identically as a Python 2
    # statement and as a Python 3 function call.
    print('''\
%(argv0)s, for extracting OpenFlow error codes from header files
usage: %(argv0)s FILE [FILE...]

This program reads the header files specified on the command line and
outputs a C source file for translating OpenFlow error codes into
strings, for use as lib/ofp-errors.c in the Open vSwitch source tree.

This program is specialized for reading lib/ofp-errors.h.  It will not
work on arbitrary header files without extensions.\
''' % {"argv0": argv0})
    sys.exit(0)

def extract_ofp_errors(filenames):
    """Parse the ``enum ofperr`` in each header and print C lookup tables.

    For every file in *filenames* the parser skips to the first line that
    begins with ``enum ofperr`` and then reads comment/enumerator pairs until
    a line that begins with ``}``.  Each indented comment is either

      * ``Expected: <message>.`` -- registers <message> as a duplicate
        definition that is allowed to occur, or
      * ``<dst>[, <dst>...].  <description>`` -- where each <dst> has the
        form ``TARGET(type)`` or ``TARGET(type,code)``, followed on the next
        line by the ``OFPERR_...`` enumerator it describes.

    The generated C source (name and comment arrays plus one decode table
    per OpenFlow version) is printed to stdout.

    Syntax errors are reported through fatal(), which exits immediately.
    Duplicate definitions are reported through error(); if any were
    reported, or an ``Expected:`` entry was never used, the function exits
    with status 1 before printing anything.
    """
    # error() counts problems in the module-level n_errors.  The final check
    # has to read that same counter: a local counter would silently ignore
    # every duplicate reported through error().
    global n_errors
    global fileName
    global inputFile
    global lineNumber

    comments = []
    names = []
    domain = {}
    reverse = {}
    for domain_name in ("OF1.0", "OF1.1", "OF1.2", "NX1.0", "NX1.1"):
        domain[domain_name] = {}
        reverse[domain_name] = {}

    expected_errors = {}

    # Destination name as written in the header -> the domains it covers.
    # Loop-invariant, so it is built once rather than once per destination.
    target_map = {"OF1.0+": ("OF1.0", "OF1.1", "OF1.2"),
                  "OF1.1+": ("OF1.1", "OF1.2"),
                  "OF1.2+": ("OF1.2",),
                  "OF1.0":  ("OF1.0",),
                  "OF1.1":  ("OF1.1",),
                  "OF1.2":  ("OF1.2",),
                  "NX1.0+": ("OF1.0", "OF1.1", "OF1.2"),
                  "NX1.0":  ("OF1.0",),
                  "NX1.1":  ("OF1.1",),
                  "NX1.1+": ("OF1.1",),
                  "NX1.2":  ("OF1.2",)}

    for fileName in filenames:
        inputFile = open(fileName)
        lineNumber = 0

        # Skip everything before the enum.
        while True:
            getLine()
            if re.match('enum ofperr', line):
                break

        while True:
            getLine()
            # Blank lines and comments that start in column 0 are ignored.
            if line.startswith('/*') or not line or line.isspace():
                continue
            elif re.match('}', line):
                break

            if not line.lstrip().startswith('/*'):
                fatal("unexpected syntax between errors")

            # Gather a possibly multi-line comment into a single string.
            comment = line.lstrip()[2:].strip()
            while not comment.endswith('*/'):
                getLine()
                if line.startswith('/*') or not line or line.isspace():
                    fatal("unexpected syntax within error")
                comment += ' %s' % line.lstrip('* \t').rstrip(' \t\r\n')
            comment = comment[:-2].rstrip()

            m = re.match(r'Expected: (.*)\.$', comment)
            if m:
                expected_errors[m.group(1)] = (fileName, lineNumber)
                continue

            # Split at the first ".  " (period, two spaces): destinations
            # before it, description after it.
            m = re.match(r'((?:.(?!\.  ))+.)\.  (.*)$', comment)
            if not m:
                fatal("unexpected syntax between errors")

            dsts, comment = m.groups()

            getLine()
            m = re.match(
                r'\s+(?:OFPERR_((?:OFP|NX)[A-Z0-9_]+))(\s*=\s*OFPERR_OFS)?,',
                line)
            if not m:
                fatal("syntax error expecting enum value")

            enum = m.group(1)

            # Bracketed text is stripped from the emitted description.
            comments.append(re.sub(r'\[[^]]*\]', '', comment))
            names.append(enum)

            for dst in dsts.split(', '):
                m = re.match(
                    r'([A-Z0-9.+]+)\((\d+|(0x)[0-9a-fA-F]+)(?:,(\d+))?\)$',
                    dst)
                if not m:
                    fatal("%s: syntax error in destination" % dst)
                targets = m.group(1)
                if m.group(3):
                    base = 16
                else:
                    base = 10
                type_ = int(m.group(2), base)
                if m.group(4):
                    code = int(m.group(4))
                else:
                    code = None

                if targets not in target_map:
                    fatal("%s: unknown error domain" % targets)
                # Python 2 orders None below every integer, so a missing
                # code used to fail the NX check and pass the OF check.
                # Spell that out instead of comparing None with an int,
                # which raises TypeError on Python 3.
                if targets.startswith('NX') and (code is None
                                                 or code < 0x100):
                    fatal("%s: NX domain code cannot be less than 0x100" % dst)
                if (targets.startswith('OF') and code is not None
                        and code >= 0x100):
                    fatal("%s: OF domain code cannot be greater than 0x100"
                          % dst)
                for target in target_map[targets]:
                    domain[target].setdefault(type_, {})
                    if code in domain[target][type_]:
                        # The code is formatted with %s because it is None
                        # for a type-only destination; integers render
                        # exactly as they do with %d.
                        msg = "%d,%s in %s means both %s and %s" % (
                            type_, code, target,
                            domain[target][type_][code][0], enum)
                        if msg in expected_errors:
                            del expected_errors[msg]
                        else:
                            error("%s: %s." % (dst, msg))
                            sys.stderr.write("%s:%d: %s: Here is the location "
                                             "of the previous definition.\n"
                                             % (domain[target][type_][code][1],
                                                domain[target][type_][code][2],
                                                dst))
                    else:
                        domain[target][type_][code] = (enum, fileName,
                                                       lineNumber)

                    if enum in reverse[target]:
                        error("%s: %s in %s means both %d,%s and %d,%s." %
                              (dst, enum, target,
                               reverse[target][enum][0],
                               reverse[target][enum][1],
                               type_, code))
                    reverse[target][enum] = (type_, code)

        inputFile.close()

    # Whatever is still registered was announced but never encountered.
    for fn, ln in expected_errors.values():
        sys.stderr.write("%s:%d: expected duplicate not used.\n" % (fn, ln))
        n_errors += 1

    if n_errors:
        sys.exit(1)

    # Each print below takes one parenthesized argument, which behaves the
    # same as a Python 2 statement and as a Python 3 function call.
    print("""\
/* Generated automatically; do not modify!     -*- buffer-read-only: t -*- */

#define OFPERR_N_ERRORS %d

struct ofperr_domain {
    const char *name;
    uint8_t version;
    enum ofperr (*decode)(uint16_t type, uint16_t code);
    enum ofperr (*decode_type)(uint16_t type);
    struct pair errors[OFPERR_N_ERRORS];
};

static const char *error_names[OFPERR_N_ERRORS] = {
%s
};

static const char *error_comments[OFPERR_N_ERRORS] = {
%s
};\
""" % (len(names),
       '\n'.join('    "%s",' % name for name in names),
       '\n'.join('    "%s",' % re.sub(r'(["\\])', r'\\\1', comment)
                 for comment in comments)))

    def output_domain(errors_map, name, description, version):
        """Print the decode functions and domain table for one version.

        errors_map maps an enum name to its (type, code) pair in this
        domain; code is None for a type-only error.
        """
        print("""
static enum ofperr
%s_decode(uint16_t type, uint16_t code)
{
    switch ((type << 16) | code) {""" % name)
        found = set()
        for enum in names:
            if enum not in errors_map:
                continue
            type_, code = errors_map[enum]
            if code is None:
                continue
            value = (type_ << 16) | code
            # Only the first enum for a given (type, code) gets a case.
            if value in found:
                continue
            found.add(value)
            print("    case (%d << 16) | %d:" % (type_, code))
            print("        return OFPERR_%s;" % enum)
        print("""\
    }

    return 0;
}

static enum ofperr
%s_decode_type(uint16_t type)
{
    switch (type) {""" % name)
        for enum in names:
            if enum not in errors_map:
                continue
            type_, code = errors_map[enum]
            if code is not None:
                continue
            print("    case %d:" % type_)
            print("        return OFPERR_%s;" % enum)
        print("""\
    }

    return 0;
}""")

        print("""
static const struct ofperr_domain %s = {
    "%s",
    %d,
    %s_decode,
    %s_decode_type,
    {""" % (name, description, version, name, name))
        for enum in names:
            if enum in errors_map:
                type_, code = errors_map[enum]
                if code is None:
                    code = -1
            else:
                # The error does not exist in this domain.
                type_ = code = -1
            print("        { %2d, %3d }, /* %s */" % (type_, code, enum))
        print("""\
    },
};""")

    output_domain(reverse["OF1.0"], "ofperr_of10", "OpenFlow 1.0", 0x01)
    output_domain(reverse["OF1.1"], "ofperr_of11", "OpenFlow 1.1", 0x02)
    output_domain(reverse["OF1.2"], "ofperr_of12", "OpenFlow 1.2", 0x03)

# Script entry point: --help anywhere on the command line wins; otherwise at
# least one file argument is required.
if __name__ == '__main__':
    if '--help' in sys.argv:
        usage()
    elif len(sys.argv) >= 2:
        extract_ofp_errors(sys.argv[1:])
    else:
        sys.stderr.write("at least one non-option argument required; "
                         "use --help for help\n")
        sys.exit(1)
