axum 操作 Postgres 数据库

1369370
2021/11/13 08:07:21

PostgreSQL 是一款天然支持异步操作的高性能开源关系型数据库。本章将讨论如何在 axum 中使用 PostgreSQL。包括:数据的增加、修改、删除、查找以及开始事务保证业务的原子性。

如果你对 PostgreSQL 不是很了解,可以通过PostgreSQL 轻松学网站进行学习。

ElephantSQL提供了免费的 PostgreSQL 数据库,对于学习而言,它提供的免费实例就足够了。你可以创建多个免费实例,并且它的实例可以选择包括香港在内的多个节点。

依赖

本章代码涉及多个依赖,下面对各依赖项进行简单说明:

  • deadpool-postgres:连接池

  • tokio-pg-mapper:自动将数据库的SELECT语句与结构体进行映射

  • tokio-pg-mapper-derive:使用 derive 简化数据库与结构体的映射

准备工作

开始之前,你需要做好以下准备工作

  • 有一个 PostgreSQL 实例,如果你不想在本地安装,可以使用 ElephantSQL 提供的免费实例。

  • 导入SQL 文件

该 SQL 文件内容如下:

CREATE TABLE account (
    id SERIAL PRIMARY KEY,
    username VARCHAR(50) NOT NULL,
    balance INTEGER NOT NULL DEFAULT 0,
    UNIQUE(username)
);
INSERT INTO
    account(username, balance)
VALUES
    ('axum.rs', 999999),
    ('foo', 0),
    ('bar', 0);

创建连接池

为了提升并发效率,我们不直接连接数据库,而是通过连接也来维护数据库连接。rust 有很多连接池的库,我们选择的是 deadpool-postgres

创建连接池的方法很多,最方便的莫过于从配置中创建,所以我们从数据库配置开始:

数据库连接池配置

let mut cfg = deadpool_postgres::Config::new();
cfg.user = Some("axum_rs".to_string()); //数据库用户名
cfg.password = Some("axum.rs".to_string()); //数据库密码
cfg.dbname = Some("axum_rs".to_string()); //数据库名称
cfg.host = Some("pg.axum.rs".to_string()); // 数据库主机
cfg.port = Some(5432); //数据库端口
cfg

本示例的配置仅供演示说明,请根据你自身的环境对数据库配置进行修改。

从配置中创建连接池

有了配置之后,可以方便的地从中创建连接池,比如:

let pool = cfg.create_pool(tokio_postgres::NoTls).unwrap();

从连接池中获取数据库连接

接下来,我们调用连接池的get()方法,从中获取一个数据库连接:

let client = pool.get().await.unwrap();

现在,可以开始数据库操作了。

获取数据库连接的完整代码

为了方便你对照学习,现将获取数据库连接的完整代码提供给你:

/// 数据库配置
fn get_cfg() -> deadpool_postgres::Config {
    let mut cfg = deadpool_postgres::Config::new();
    cfg.user = Some("axum_rs".to_string()); //数据库用户名
    cfg.password = Some("axum.rs".to_string()); //数据库密码
    cfg.dbname = Some("axum_rs".to_string()); //数据库名称
    cfg.host = Some("pg.axum.rs".to_string()); // 数据库主机
    cfg.port = Some(5432); //数据库端口
    cfg
}
/// 从连接池中获取数据库连接
async fn get_client() -> Result<deadpool_postgres::Client, String> {
    // 通过配置文件创建连接池
    let pool = get_cfg()
        .create_pool(tokio_postgres::NoTls)
        .map_err(|err| err.to_string())?;
    // 从连接池中获取数据库连接
    pool.get().await.map_err(|err| err.to_string())
}

定义模型

为了与数据表进行映射,我们需要定义结构体。该结构体和数据表字段需要对应好。

#[derive(PostgresMapper, Debug)]
#[pg_mapper(table = "account")]
pub struct Account {
    pub id: i32,
    pub username: String,
    pub balance: i32,
}

眼尖的你可能发现了,除了Debug之外,我们还使用了PostgresMapperpg_mapper

  • PostgresMapper:让你的结构体能使用快速映射方法,将数据库查询结果自动填充到结构体

  • pg_mapper:指定这个结构体关联的数据表

插入数据

首先,我们来看一下如何向数据库中添加一条新数据:

async fn insert(Path(username): Path<String>) -> Result<&'static str, String> {
    let create_user = CreateAccount {
        username,
        balance: 0,
    };
    let client = get_client().await?;
    let stmt = client
        .prepare("INSERT INTO account (username, balance) VALUES ($1, $2)")
        .await
        .map_err(|err| err.to_string())?;
    let rows = client
        .execute(&stmt, &[&create_user.username, &create_user.balance])
        .await
        .map_err(|err| err.to_string())?;
    if rows < 1 {
        return Err("Insert account failed".to_string());
    }
    Ok("Successfully insert account")
}

让我们对这个函数进行分解说明:

首先,创建了一个 CreateAccount的实例,它是一个结构体,该结构体的定义如下:

pub struct CreateAccount {
    pub username: String,
    pub balance: i32,
}

然后,通过 get_clinet() 函数从连接池中获取连接。

之后,通过client.prepare() 方法获取一个“语句”对象。prepare方法会对 SQL 语句进行预编译,这是防止 SQL 注入最有效的方法。

再然后,调用client.execute()执行 SQL 语句,它将返回受影响的行数。

通过判断受影响的行数就能知道这个 SQL 语句是否执行成功。

execute()

对于 INSERT/UPDATE/DELETE 等语句,通常使用execute()方法来执行 SQL 语句。它接收两个参数,一个是预编译 SQL 语句之后返回的对象,另一个是 SQL 语句所需要的参数。

通常我们是通过它的返回值来判断 SQL 语句是否成功。

注意,由于 PostgreSQL 不提供类似其它数据库的 LastInsertedID功能,所以使用execute()执行INSERT语句并不能获得最后插入的 ID。

修改数据

我们以修改指定用户余额作为演示:

async fn update(Path((id, balance)): Path<(i32, i32)>) -> Result<&'static str, String> {
    let client = get_client().await?;
    let stmt = client
        .prepare("UPDATE account SET balance=$1 WHERE id=$2")
        .await
        .map_err(|err| err.to_string())?;
    let rows = client
        .execute(&stmt, &[&balance, &id])
        .await
        .map_err(|err| err.to_string())?;
    if rows < 1 {
        return Err("Update account failed".to_string());
    }
    Ok("Successfully update account")
}

和插入数据类似,我们使用的是execute(),不同之处在于,这里的 SQL 语句是UPDATE

删除数据

和修改数据类型,不同之处在于,这里的 SQL 语句是DELETE,如下所示:

async fn delete(Path(id): Path<i32>) -> Result<&'static str, String> {
    let client = get_client().await?;
    let stmt = client
        .prepare("DELETE FROM account WHERE id=$1")
        .await
        .map_err(|err| err.to_string())?;
    let rows = client
        .execute(&stmt, &[&id])
        .await
        .map_err(|err| err.to_string())?;
    if rows < 1 {
        return Err("Delete account failed".to_string());
    }
    Ok("Successfully delete account")
}

查找多条数据

现在我们来看一下怎么查找数据。

async fn list() -> Result<String, String> {
    let client = get_client().await?;
    let stmt = client
        .prepare("SELECT id,username,balance FROM account ORDER BY id DESC")
        .await
        .map_err(|err| err.to_string())?;
    let account_list = client
        .query(&stmt, &[])
        .await
        .map_err(|err| err.to_string())?
        .iter()
        .map(|row| Account::from_row_ref(&row).unwrap())
        .collect::<Vec<Account>>();
    let mut output = Vec::with_capacity(account_list.len());
    for account in account_list.iter() {
        output.push(format!("{:?}", account));
    }
    Ok(output.join("\n"))
}

前两步都一样,获取数据库连接,然后对 SQL 语句进行预编译。

之后,我们通过client.query()方法来执行数据的查询。它是一个迭代器,我们对所有查询到的数据用map()进行处理。

map()里的闭包只有一行:

|row| Account::from_row_ref(&row).unwrap()

它的含义是,将查询到的数据,逐行映射到 Account 结构体。我们可以把它理解为:

|row| {
        Account: {
            id: row.id,
            username: row.username.clone(),
            // ...
        }
}

这里之所以能写成 Account::from_row_ref(&row) 是因为我们使用了 tokio-pg-mapper

之后,使用collect()map() 处理过后的数据变成Vec<Account>集合。

query()

该方法通常用于查询多条数据。配合迭代器和map()collect(),可以很方便、直观、高效地获取到包含满足条件的记录的集合。

实际上,对于单条记录的查询,也可以使用该方法。

另外,如果要获取 PostgreSQL 最后插入的 ID,我们也使用该方法。

查询单条记录

和查询多条记录一样,我们也用 query()

async fn find(Path(id): Path<i32>) -> Result<String, String> {
    let client = get_client().await?;
    let stmt = client
        .prepare("SELECT id,username,balance FROM account WHERE id=$1 ORDER BY id DESC LIMIT 1")
        .await
        .map_err(|err| err.to_string())?;
    let account = client
        .query(&stmt, &[&id])
        .await
        .map_err(|err| err.to_string())?
        .iter()
        .map(|row| Account::from_row_ref(&row).unwrap())
        .collect::<Vec<Account>>()
        .pop()
        .ok_or(format!("Couldn't find account #{}", id))?;
    Ok(format!("{:?}", account))
}

之前的操作和查询多条一毛一样。在 collect() 获取集合后,我们调用了 pop() 方法,如果有记录就能得到一条记录,如果没有记录,就会用后面的ok_or()来返回一个Err<String>,就是我们的报错信息。

事务

我们通过模拟账户之间转账来演示事务的处理。

async fn transfer(
    Path((from_id, to_id, balance)): Path<(i32, i32, i32)>,
) -> Result<&'static str, String> {
    let mut client = get_client().await?;
    let tx = client.transaction().await.map_err(|err| err.to_string())?;

    // 修改出账记录
    let stmt = tx
        .prepare("UPDATE account SET balance=balance-$1 WHERE id=$2 AND balance>=$1")
        .await
        .map_err(|err| err.to_string())?;
    match tx.execute(&stmt, &[&balance, &from_id]).await {
        Ok(_rows) if _rows > 0 => {
            // 检查受影响的行数
            // 如果大于零表示账户存在并且余额足够
            // 不必做其余操作,等待最终的事务提交
        }
        _ => {
            // 回滚事务
            tx.rollback().await.map_err(|err| err.to_string())?;
            // 提前结束函数,将错误信息返回
            return Err("Step 1 failed".to_string());
        }
    };

    // 修改入账记录
    let stmt = tx
        .prepare("UPDATE account SET balance=balance+$1 WHERE id=$2")
        .await
        .map_err(|err| err.to_string())?;
    match tx.execute(&stmt, &[&balance, &to_id]).await {
        Ok(_rows) if _rows > 0 => {
            // 检查受影响的行数
            // 如果大于零表示入账记录修改成功
            // 不必做其余操作,等待最终的事务提交
        }
        _ => {
            // 回滚事务
            tx.rollback().await.map_err(|err| err.to_string())?;
            // 提前结束函数,将错误信息返回
            return Err("Step 2 failed".to_string());
        }
    };

    // 提交事务
    tx.commit().await.map_err(|err| err.to_string())?;
    Ok("Successfully transfer")
}

获取到连接之后,我们并没有像之前那样调用它的 prepare(),而是调用了它的 transaction() 来开启事务。注意,此时要求连接是mut的。

之后就是prepareexecute了。不同的是,我们用的是事务的相关方法,而不是数据库连接的相关方法。

值得注意的是,我们判断execute如果没有成功执行,调用了tx.rollback() 对事务进行回滚,并返回错误信息。

本章讨论了在 axum 集成 PostgreSQL 数据库的方法。完整代码可以在代码仓库找到。

思考题

如何获取 PostgreSQL 最后插入的 ID?

提示:

  • INSERT语句使用RETURNING,你可以查看这里

  • 使用query()而不是execute()(类似查找单条记录)