本文的前置章节 hive高阶聚合,在这篇文章中详细介绍了 hive/spark sql 的高阶聚合在多层级聚合场景的应用。相信在带来效率提升的同时也会有一个烦恼,那就是 grouping__id 的计算。对于博主所在公司的中台同时存在 hive2.1 和 spark 3.0 两种引擎,上篇介绍到新老版本的 grouping__id 的计算逻辑是不一样的。因此本文提供一个使用 rust 开发的小工具用于快速计算(下文称 fast-gid 即:快速 grouping__id)

一、代码

项目依赖如下

[package]
name = "fast-gid"
version = "0.1.0"
edition = "2021"

[dependencies]
regex = "1"
prettytable-rs = "0.10.0"

1.1 确认新老版本逻辑

fast-gid 需要兼容两个版本的计算逻辑,因此需要显式地传入一个 flag

fn is_new_version() -> bool {
    // 接受一行命令行输入
    eprint!("是否启动新版本计算逻辑[y/n]: ");
    let mut is_version_str = String::new();
    match io::stdin().read_line(&mut is_version_str) {
        Ok(_) => {}
        Err(err) => {
            panic!("[{}] - {}", err.kind(), err)
        }
    }
    is_version_str = is_version_str.trim().to_lowercase();
    if is_version_str == "y" {
        return true;
    } else if is_version_str == "n" {
        return false;
    } else {
        panic!("期望 y or n but get {}", is_version_str)
    }
}

1.2 传入 group by

在进行 sql 开发时,往往伴随着代码格式化而被添加若干换行、制表符等,几乎所有的程序默认接收控制台输入的终止符都是 \n,因此在接收 sql 时需要特殊处理。众所周知 sql 的结束符是;,抽象出一个函数,因为在接收 grouping sets 时同样需要使用

fn read_with_char(end: &str) -> String {
    let mut result = String::new();
    loop {
        let mut line = String::new();
        match io::stdin().read_line(&mut line) {
            Ok(_) => {
                result.push_str(line.as_str());
                if line.trim().ends_with(end) {
                    result = result.replace(end, "");
                    break;
                }
            }
            Err(err) => {
                panic!("[{}] - {}", err.kind(), err)
            }
        }
    }
    return result;
}

下面是解析 group by 的逻辑,将每个字段存储数组中

fn parse_group_by() -> Vec<String> {
    eprint!("输入 group by 后的列表[以;结尾]: ");
    let group_by_str = read_with_char(";");
    let mut vec: Vec<String> = Vec::new();
    for x in group_by_str.split(",") {
        vec.push(x.trim().to_lowercase())
    }
    return vec;
}

1.3 传入 grouping sets

接收逻辑同上,但解析 grouping sets 略微复杂,

规定传入格式如下:

(a,b,c,d,e),
(a,b,c,d),
(a,b,c),
(a,b),
(a),
()

这里使用正则提取

fn parse_grouping_sets() -> Vec<Vec<String>> {
    println!("输入 grouping sets (*)的列表[以;结尾,按回车分割]: ");
    // 使用正则提取
    let re = Regex::new(r"\(([^)]*)\)").unwrap();
    let mut result = Vec::new();
    let grouping_sets = read_with_char(";");
    for x in grouping_sets.split("\n") {
        for cap in re.captures_iter(x) {
            let mut vec = Vec::new();
            for set in cap[1].split(",") {
                vec.push(set.trim().to_lowercase())
            }
            result.push(vec)
        }
    }
    return result;
}

1.4 计算 grouping__id

下面遍历 grouping sets 得到的层级列表,每个层级均是一个数组,根据 grouping__id 计算逻辑,每个层级都需要遍历一次 group by 列表去判断 group by 中每个字段在层级中是否存在并结合 is_new_version判断是追加 0 还是 1

fn computer_grouping_id(is_new_version: bool, group_by: Vec<String>, grouping_sets: Vec<Vec<String>>) {
    let mut table = Table::new();
    table.set_format(*format::consts::FORMAT_NO_LINESEP_WITH_TITLE);
    // 设置表头
    table.set_titles(row!["grouping__id","grouping sets"]);
    for set in grouping_sets {
        let grouping_id = _computer_grouping_id(is_new_version, &group_by, &set);
        table.add_row(row![grouping_id, format!("{:?}",set)]);
    }
    table.printstd();
}

fn _computer_grouping_id(is_new_version: bool, group_by: &Vec<String>, grouping_set: &Vec<String>) -> isize {
    let mut grouping_id_str = String::new();
    for x in group_by {
        let exist = grouping_set.contains(&x);
        if is_new_version {
            if exist {
                grouping_id_str.push_str("0")
            } else {
                grouping_id_str.push_str("1")
            }
        } else {
            if exist {
                grouping_id_str.push_str("1")
            } else {
                grouping_id_str.push_str("0")
            }
        }
    }
    if !is_new_version {
        // 反转字符串
        grouping_id_str = reverse_string(grouping_id_str.as_str())
    }

    return isize::from_str_radix(&grouping_id_str, 2).unwrap();
}

需要注意_computer_grouping_id函数的参数为 Vec 的引用,因为group_by在传入函数时所有权会发生移交,在下一次循环时会发生报错

1.5 完整代码

main.rs

use std::io;
use prettytable::{format, row, Table};
use regex::Regex;

fn main() {
    display();
    let version = is_new_version();
    let group_by = parse_group_by();
    let grouping_sets = parse_grouping_sets();
    computer_grouping_id(version, group_by, grouping_sets);
    enter_any_exit();
}


fn is_new_version() -> bool {
    // 接受一行命令行输入
    eprint!("是否启动新版本计算逻辑[y/n]: ");
    let mut is_version_str = String::new();
    match io::stdin().read_line(&mut is_version_str) {
        Ok(_) => {}
        Err(err) => {
            panic!("[{}] - {}", err.kind(), err)
        }
    }
    is_version_str = is_version_str.trim().to_lowercase();
    if is_version_str == "y" {
        return true;
    } else if is_version_str == "n" {
        return false;
    } else {
        panic!("期望 y or n but get {}", is_version_str)
    }
}

fn parse_group_by() -> Vec<String> {
    eprint!("输入 group by 后的列表[以;结尾]: ");
    let group_by_str = read_with_char(";");
    let mut vec: Vec<String> = Vec::new();
    for x in group_by_str.split(",") {
        vec.push(x.trim().to_lowercase())
    }
    return vec;
}

fn parse_grouping_sets() -> Vec<Vec<String>> {
    println!("输入 grouping sets (*)的列表[以;结尾,按回车分割]: ");
    // 使用正则提取
    let re = Regex::new(r"\(([^)]*)\)").unwrap();
    let mut result = Vec::new();
    let grouping_sets = read_with_char(";");
    for x in grouping_sets.split("\n") {
        for cap in re.captures_iter(x) {
            let mut vec = Vec::new();
            for set in cap[1].split(",") {
                vec.push(set.trim().to_lowercase())
            }
            result.push(vec)
        }
    }
    return result;
}

fn computer_grouping_id(is_new_version: bool, group_by: Vec<String>, grouping_sets: Vec<Vec<String>>) {
    let mut table = Table::new();
    table.set_format(*format::consts::FORMAT_NO_LINESEP_WITH_TITLE);
    // 设置表头
    table.set_titles(row!["grouping__id","grouping sets"]);
    for set in grouping_sets {
        let grouping_id = _computer_grouping_id(is_new_version, &group_by, &set);
        table.add_row(row![grouping_id, format!("{:?}",set)]);
    }
    table.printstd();
}

fn _computer_grouping_id(is_new_version: bool, group_by: &Vec<String>, grouping_set: &Vec<String>) -> isize {
    let mut grouping_id_str = String::new();
    for x in group_by {
        let exist = grouping_set.contains(&x);
        if is_new_version {
            if exist {
                grouping_id_str.push_str("0")
            } else {
                grouping_id_str.push_str("1")
            }
        } else {
            if exist {
                grouping_id_str.push_str("1")
            } else {
                grouping_id_str.push_str("0")
            }
        }
    }
    if !is_new_version {
        // 反转字符串
        grouping_id_str = reverse_string(grouping_id_str.as_str())
    }

    return isize::from_str_radix(&grouping_id_str, 2).unwrap();
}

fn reverse_string(input: &str) -> String {
    input.chars().rev().collect()
}

fn read_with_char(end: &str) -> String {
    let mut result = String::new();
    loop {
        let mut line = String::new();
        match io::stdin().read_line(&mut line) {
            Ok(_) => {
                result.push_str(line.as_str());
                if line.trim().ends_with(end) {
                    result = result.replace(end, "");
                    break;
                }
            }
            Err(err) => {
                panic!("[{}] - {}", err.kind(), err)
            }
        }
    }
    return result;
}

fn enter_any_exit() {
    println!("按任意键结束...");
    let mut flag = String::new();
    io::stdin().read_line(&mut flag).unwrap();
}


fn display() {
    println!("fast-gid v1.0 for rust");
    println!("快速计算 hsql/sparkSql 高阶聚合的 grouping__id");
    println!();
}

Github 仓库地址:https://github.com/kpretty/fast-gid

二、使用

usage