个板马,喵莫比喵?

通过AST语法树解析来实现一个简易的表达式求值器

最近工作中需要搞一个简易的规则引擎,在翻了Github之后发现了danthedeckie/simpleeval这个库,用起来挺方便的,不过在查阅源码后发现,simpleeval原理是复用Python的ast库来解析计算的,依赖于Python语法,通过构建一个安全沙箱来隔离运行的。
不过在工作业务中,考虑到后续极有可能有自定义语法的需求,所以想来想去,还是写一个简单的解析器比较合适。

代码将基于Python实现。

业务需求

求值器需要支持以下能力:

  1. 大小比较:整数、小数、负数的大小比较。
  2. 布尔运算andor 、... 这种逻辑运算。
  3. 字符串比较:对字符串的“==”比较
  4. 运算优先级:需支持括号优先计算。
  5. 变量传入:需支持传入变量代入计算。

整体架构

整个求值器分为三个模块:

  1. 词法解析器(Tokenizer):将 expr 表达式字符串拆解成 Token 序列。
  2. 语法解析器(Parser):根据 Token 序列构建 AST。
  3. 求值器(Evaluator):递归遍历 AST,结合变量上下文 (names) 计算出最终布尔值。
expr字符串 → Token列表 → AST → 求值结果

词法解析

词法解析器的输入是原始表达式 expr 字符串,输出是带有类型和值的 Token 列表。根据业务需求,在这里需要识别:

  • 关键字andornotTrueFalse
  • 变量名:以字母或下划线开头,后跟字母、数字或下划线
  • 数字:支持整数、浮点数、科学计数法,以及合法的负数
  • 字符串:以双引号""包裹的字符串
  • 运算符>>=<<===
  • 括号()

其中,负数的处理需要特殊判断,避免与减号混淆(不过目前还不支持减法运算,因此单独的负号仅在表达式开头或运算符后出现时才被解析为数字的一部分)。

核心部分代码:

def tokenize(text: str) -> list[Token]:
    tokens = []
    i = 0
    while i < len(text):
        ch = text[i]
        if ch.isspace():
            i += 1
            continue
        # 识别标识符和关键字
        if ch.isalpha() or ch == '_':
            # ... 收集并判断关键字
        # 识别数字(含负数、科学计数法)
        if ch.isdigit() or (ch == '-' and 合法位置):
            # ... 收集数字字面量
        # 识别字符串
        if ch == '"':
            # ... 收集字符串字面量
        # 识别运算符、括号等
        # ...
    return tokens

词法解析器最终返回一个 Tokens 列表

语法解析

语法解析器采用递归下降方法,根据运算符优先级定义解析函数(优先级从低到高):

优先级运算符解析函数
最低or_parse_or
and_parse_and
not_parse_not
比较运算_parse_compare
最高括号 / 原子项_parse_term

每个解析函数会消费与之匹配的 Token,然后返回对应的 AST 节点。

AST 节点类型(使用 dataclass 定义):

  • Constant:布尔值、数字、字符串
  • Name:变量引用
  • BoolOpand / or 运算,包含多个子节点
  • UnaryOpnot 运算
  • Compare:两两比较运算(如 a > b

例如,表达式 not temperature >= 18.03 and state == "active" 生成的 AST 大概长下面这个样子:

BoolOp(op=AND)
├── UnaryOp(op=NOT)
│   └── Compare(left=Name('temperature'), op=GTE, right=Constant(18.03))
└── Compare(left=Name('state'), op=EQ, right=Constant("active"))

求值器

求值函数 eval_ast(node, names) 通过递归计算AST来,根据节点类型进行求值:

  • Constant:直接返回值
  • Name:从 names 字典中查找变量值,未定义则抛出 NameError
  • BoolOp:短路求值(and 遇到 False 立即返回,or 遇到 True 立即返回)
  • UnaryOp:对操作数取反
  • Compare:比较左右子节点的值,返回布尔结果

最后暴露一个便捷函数 evaluate_expression(text, names) 完成解析+求值。

测试

  • 基础测试:
    让AI写了一段测试代码和少量测试用例,测试输出结果如下:

expr_eval_基础测试1.png

  • 性能测试
    同样让AI写了一段与simpleeval对比的测试代码,上面一段是simpleeval,下面是我写的eval:

expr_eval_性能测试1.png

完整代码

from dataclasses import dataclass


class ParseError(ValueError):
    """解析错误异常类"""

    pass


class TokenType:
    """Token类型枚举"""

    NAME = "NAME"  # 变量名
    NUMBER = "NUMBER"  # 数字
    STRING = "STRING"  # 字符串
    BOOL = "BOOL"  # 布尔值
    AND = "AND"  # 逻辑与
    OR = "OR"  # 逻辑或
    NOT = "NOT"  # 逻辑非
    GT = ">"  # 大于
    GTE = ">="  # 大于等于
    LT = "<"  # 小于
    LTE = "<="  # 小于等于
    EQ = "=="  # 等于
    LPAREN = "("  # 左括号
    RPAREN = ")"  # 右括号
    EOF = "EOF"  # 结束标记


@dataclass(frozen=True)
class Token:
    """词法单元(Token)数据类

    Attributes:
        type: Token的类型(如NAME、NUMBER等)
        value: Token的值(如变量名、数字等)
        pos: Token在原始文本中的起始位置
    """

    type: str
    value: str | None
    pos: int


class Node:
    """AST节点基类"""

    pass


@dataclass(frozen=True)
class Name(Node):
    """变量名节点

    Attributes:
        id: 变量名称
    """

    id: str


@dataclass(frozen=True)
class Constant(Node):
    """常量节点

    Attributes:
        value: 常量值(bool、int、float、str)
    """

    value: object


@dataclass(frozen=True)
class BoolOp(Node):
    """布尔运算节点

    Attributes:
        op: 运算符类型(AND或OR)
        values: 操作数节点列表
    """

    op: str
    values: list[Node]


@dataclass(frozen=True)
class Compare(Node):
    """比较运算节点(仅支持两两比较)

    Attributes:
        left: 左操作数
        op: 比较运算符(>、>=、<、<=、==)
        right: 右操作数
    """

    left: Node
    op: str
    right: Node


@dataclass(frozen=True)
class UnaryOp(Node):
    """一元运算节点

    Attributes:
        op: 运算符类型(NOT)
        operand: 操作数节点
    """

    op: str
    operand: Node


def tokenize(text: str) -> list[Token]:
    """词法解析器:将输入的文本字符串转换为Token列表

    Args:
        text: 待解析的表达式字符串

    Returns:
        Token列表,包含EOF标记

    Raises:
        ParseError: 遇到无法识别的字符时抛出
    """
    tokens: list[Token] = []
    i = 0
    length = len(text)

    while i < length:
        ch = text[i]
        if ch.isspace():
            i += 1
            continue

        if ch.isalpha() or ch == "_":
            start = i
            i += 1
            while i < length and (text[i].isalnum() or text[i] == "_"):
                i += 1
            value = text[start:i]
            if value == "and":
                tokens.append(Token(TokenType.AND, value, start))
            elif value == "or":
                tokens.append(Token(TokenType.OR, value, start))
            elif value == "not":
                tokens.append(Token(TokenType.NOT, value, start))
            elif value == "True" or value == "False":
                tokens.append(Token(TokenType.BOOL, value, start))
            else:
                tokens.append(Token(TokenType.NAME, value, start))
            continue

        if ch.isdigit() or (ch == "-" and i + 1 < length and text[i + 1].isdigit()):
            # 检查负号是否合法(只在特定位置允许负数)
            if ch == "-":
                # 负号只在以下情况合法:开头、左括号后、运算符后
                if tokens and tokens[-1].type not in {
                    TokenType.LPAREN,
                    TokenType.AND,
                    TokenType.OR,
                    TokenType.NOT,
                    TokenType.GT,
                    TokenType.GTE,
                    TokenType.LT,
                    TokenType.LTE,
                    TokenType.EQ,
                }:
                    raise ParseError(f"Unexpected '-' at position {i}")

            start = i
            if ch == "-":
                i += 1  # 跳过负号

            # 支持整数部分
            while i < length and text[i].isdigit():
                i += 1

            # 支持小数点和小数部分
            if i < length and text[i] == ".":
                i += 1
                while i < length and text[i].isdigit():
                    i += 1

            # 支持科学计数法 (e.g., 1e10, 1.5e-10, 2E+5)
            if i < length and text[i] in ("e", "E"):
                i += 1
                # 可选的正负号
                if i < length and text[i] in ("+", "-"):
                    i += 1
                # 指数部分必须有数字
                if i >= length or not text[i].isdigit():
                    raise ParseError(f"Invalid scientific notation at position {start}")
                while i < length and text[i].isdigit():
                    i += 1

            tokens.append(Token(TokenType.NUMBER, text[start:i], start))
            continue

        # 解析字符串字面量(双引号)
        if ch == '"':
            start = i
            i += 1  # 跳过开始的双引号
            string_value = ""

            while i < length:
                if text[i] == '"':
                    # 找到结束的双引号
                    i += 1
                    tokens.append(Token(TokenType.STRING, string_value, start))
                    break
                elif text[i] == "\\":
                    # 支持转义字符
                    if i + 1 < length:
                        next_char = text[i + 1]
                        if next_char == "n":
                            string_value += "\n"
                        elif next_char == "t":
                            string_value += "\t"
                        elif next_char == "r":
                            string_value += "\r"
                        elif next_char == "\\":
                            string_value += "\\"
                        elif next_char == '"':
                            string_value += '"'
                        else:
                            string_value += next_char
                        i += 2
                    else:
                        raise ParseError(f"Incomplete escape sequence at position {i}")
                else:
                    string_value += text[i]
                    i += 1
            else:
                # while 循环正常结束,说明没有找到结束引号
                raise ParseError(f"Unterminated string starting at position {start}")
            continue

        if ch == ">":
            if i + 1 < length and text[i + 1] == "=":
                tokens.append(Token(TokenType.GTE, ">=", i))
                i += 2
            else:
                tokens.append(Token(TokenType.GT, ">", i))
                i += 1
            continue

        if ch == "<":
            if i + 1 < length and text[i + 1] == "=":
                tokens.append(Token(TokenType.LTE, "<=", i))
                i += 2
            else:
                tokens.append(Token(TokenType.LT, "<", i))
                i += 1
            continue

        if ch == "=":
            if i + 1 < length and text[i + 1] == "=":
                tokens.append(Token(TokenType.EQ, "==", i))
                i += 2
                continue
            raise ParseError(f"Unexpected '=' at position {i}")

        if ch == "(":
            tokens.append(Token(TokenType.LPAREN, "(", i))
            i += 1
            continue

        if ch == ")":
            tokens.append(Token(TokenType.RPAREN, ")", i))
            i += 1
            continue

        # 如果遇到单独的负号(后面不是数字),报错
        if ch == "-":
            raise ParseError(
                f"Unexpected '-' at position {i} (subtraction not supported)"
            )

        raise ParseError(f"Unexpected character '{ch}' at position {i}")

    tokens.append(Token(TokenType.EOF, None, length))
    return tokens


class Parser:
    """语法分析器:将Token列表转换为抽象语法树(AST)

    使用递归下降解析法,按运算符优先级解析:
    优先级从低到高:or -> and -> not -> 比较运算 -> 原子项
    """

    def __init__(self, tokens: list[Token]):
        self.tokens = tokens
        self.index = 0  # 当前解析位置

    def _current(self) -> Token:
        return self.tokens[self.index]

    def _advance(self) -> Token:
        token = self.tokens[self.index]
        self.index += 1
        return token

    def _expect(self, kind: str) -> Token:
        token = self._current()
        if token.type != kind:
            raise ParseError(f"Expected {kind}, got {token.type} at {token.pos}")
        return self._advance()

    def parse(self) -> Node:
        node = self._parse_or()
        if self._current().type != TokenType.EOF:
            token = self._current()
            raise ParseError(f"Unexpected token {token.type} at {token.pos}")
        return node

    def _parse_or(self) -> Node:
        """解析or运算(优先级最低)"""
        left = self._parse_and()
        values = [left]
        while self._current().type == TokenType.OR:
            self._advance()
            values.append(self._parse_and())
        if len(values) == 1:
            return left
        return BoolOp(TokenType.OR, values)

    def _parse_and(self) -> Node:
        """解析and运算"""
        left = self._parse_not()
        values = [left]
        while self._current().type == TokenType.AND:
            self._advance()
            values.append(self._parse_not())
        if len(values) == 1:
            return left
        return BoolOp(TokenType.AND, values)

    def _parse_not(self) -> Node:
        """解析not运算(支持连续not)"""
        if self._current().type == TokenType.NOT:
            self._advance()
            return UnaryOp(TokenType.NOT, self._parse_not())
        return self._parse_compare()

    def _parse_compare(self) -> Node:
        """解析比较运算(仅支持两两比较,如 a < b)"""
        left = self._parse_term()
        if self._current().type in {
            TokenType.GT,
            TokenType.GTE,
            TokenType.LT,
            TokenType.LTE,
            TokenType.EQ,
        }:
            op = self._advance().type
            right = self._parse_term()
            return Compare(left, op, right)
        return left

    def _parse_term(self) -> Node:
        """解析基本项:括号表达式、变量名、布尔值、数字、字符串"""
        token = self._current()
        if token.type == TokenType.LPAREN:
            self._advance()
            node = self._parse_or()
            self._expect(TokenType.RPAREN)
            return node
        if token.type == TokenType.NAME:
            self._advance()
            return Name(token.value or "")
        if token.type == TokenType.BOOL:
            self._advance()
            return Constant(token.value == "True")
        if token.type == TokenType.NUMBER:
            self._advance()
            # 支持浮点数解析
            return Constant(float(token.value))
        if token.type == TokenType.STRING:
            self._advance()
            # 字符串字面量
            return Constant(token.value)
        raise ParseError(f"Unexpected token {token.type} at {token.pos}")


def parse_expression(text: str) -> Node:
    tokens = tokenize(text)
    parser = Parser(tokens)
    return parser.parse()


def eval_ast(node: Node, names: dict) -> bool:
    """递归求值抽象语法树

    Args:
        node: AST节点
        names: 变量名到值的映射字典

    Returns:
        节点的求值结果

    Raises:
        NameError: 引用了未定义的变量
        ParseError: 遇到不支持的节点或运算符
    """
    if isinstance(node, Constant):
        return node.value
    if isinstance(node, Name):
        if node.id not in names:
            raise NameError(f"Name '{node.id}' is not defined")
        return names[node.id]
    if isinstance(node, BoolOp):
        if node.op == TokenType.AND:
            for value in node.values:
                if not bool(eval_ast(value, names)):
                    return False
            return True
        if node.op == TokenType.OR:
            for value in node.values:
                if bool(eval_ast(value, names)):
                    return True
            return False
        raise ParseError(f"Unsupported boolean operator {node.op}")
    if isinstance(node, UnaryOp):
        if node.op == TokenType.NOT:
            return not bool(eval_ast(node.operand, names))
        raise ParseError(f"Unsupported unary operator {node.op}")
    if isinstance(node, Compare):
        # 两两比较:a < b
        left_val = eval_ast(node.left, names)
        right_val = eval_ast(node.right, names)
        if node.op == TokenType.GT:
            return left_val > right_val
        elif node.op == TokenType.GTE:
            return left_val >= right_val
        elif node.op == TokenType.LT:
            return left_val < right_val
        elif node.op == TokenType.LTE:
            return left_val <= right_val
        elif node.op == TokenType.EQ:
            return left_val == right_val
        else:
            raise ParseError(f"Unsupported comparison operator {node.op}")
    raise ParseError(f"Unsupported node type {type(node).__name__}")


def evaluate_expression(text: str, names: dict) -> bool:
    """便捷函数:解析并求值表达式

    Args:
        text: 表达式字符串
        names: 变量名到值的映射字典

    Returns:
        表达式的布尔求值结果
    """
    ast_root = parse_expression(text)
    return bool(eval_ast(ast_root, names))
通过AST语法树解析来实现一个简易的表达式求值器

https://blog.mapotofu.cn/archives/impl_simple_expr_eval_by_using_ast_parsing.html

作者

麻婆豆腐

发布时间

2026-04-07

添加新评论