STM32N6实时手写数字识别:MCU AI推理实战指南
手写数字识别这个题目,只要是做过机器学习的朋友,基本不会陌生。哪怕没亲手训练过模型,十有八九也见过那个经典到几乎成为符号的数据集——黑底白字或者白底黑字,一张张小尺寸的数字图片,最后让模型去判断究竟是 0 还是 9。它已经在很长时间里,都是很多人接触图像分类和神经网络时的第一课。但也正因为足够经典、足够简单、足够直观,它反倒特别适合拿来验证嵌入式 AI 这条链路到底能不能走通。这件事的链路非常完整:从手写采集、预处理、模型推理到结果输出,一个环节都不少,而任务本身的复杂度又不至于把人一上来就劝退。
实时推理和采集完再推理是两码事。前者要求算力能同时撑住书写的流畅度和推理的准确度。所以,当我想认真看看 STM32N6 这颗芯片究竟能不能把实时推理这件事扛起来时,就借这个最小但完整的例子,来弄清楚一件更重要的事:一块 MCU,到底能不能开始像一个真正的边缘 AI 节点那样工作。
STM32 这个名号,做嵌入式的应该都非常熟悉。过去它更多出现在控制、传感、通信这些典型场景里,是相当经典的一类微控制器平台,开发生态成熟,资料多,社区也大。而 STM32N6 有意思的地方在于,它开始明显往边缘 AI 这个方向再走一步。它不只是让传统 MCU 任务跑得更快,而是在尝试回答一个很现实的问题:如果希望一些简单但有价值的推理任务不依赖云端,直接在本地设备上实时完成,那 STM32 这一代芯片能不能把这件事做得足够像样。
1、模型制作
如果做过一点深度学习就会知道,手写数字识别这件事本身并不难,甚至可以说是图像分类里最经典的入门题。常见的数据集就是 MNIST,图片尺寸只有 28×28,内容也很单纯,就是 0 到 9 这十个数字的手写体。也正因为它足够标准,大家通常不会把重点放在“能不能识别出来”上,而是会更关心另一件事:在识别准确率还不错的前提下,模型能不能尽可能小、尽可能轻、尽可能适合部署到端侧设备上。
在 PC 上跑一个大模型把数字认出来,没什么稀奇的。真正有意思的是,让它跑在 STM32N6 这种嵌入式平台上,而且还希望是实时的。那模型就不能只顾着准,还得顾内存、顾算力、顾部署格式,甚至连输入输出的数据类型都得提前想好。
所以在模型结构上,没有走特别复杂的路线,而是选了一个很典型、也很适合这个任务的小型 CNN。输入是 28×28 的灰度图,前面用两层卷积去提取边缘和纹理特征,再接一次池化做降采样,后面再补一层更高层的卷积去抓数字的整体轮廓,最后经过全连接层输出 10 个类别的概率。这个思路其实很朴素,说白了,就是先看笔画,再看形状,最后判断是几。
这里先导出了普通的 TFLite 浮点模型,然后又进一步做了 INT8 量化,最后得到一个 mnist_model_int8.tflite 文件。目的非常明确:尽可能把模型压小,同时让它更适合在嵌入式硬件上运行。
2、STM32N6配置
回到 STM32N6 这块板子,它之所以适合做这类实时手写数字识别,一个很重要的原因就是整个平台在存储访问和计算能力这两件事上都给得比较足。比如它使用 200MHz 的 XSPI 来做外部存储通讯,这看起来像是个很底层的配置,但真做部署的时候你会发现它其实非常关键。因为模型、权重、运行时缓冲区,很多时候都离不开外部存储的配合。如果这条链路不够顺,前面模型做得再轻,后面推理也很容易卡在数据吞吐上。
800MHz 的 NPU 配合 800MHz 的 CPU 主频,为端侧 AI 推理留出了空间。CPU 可以负责图像采集、预处理、任务调度这些外围工作,NPU 则专注神经网络推理本身。整个系统的分工就比较清楚了——不是所有事情都压在一个核上硬扛,而是让不同模块各干各的,这也是实时推理能够跑顺的重要基础。
3、画板逻辑实现
手写数字识别这件事,核心不是把一张现成的 28×28 图片丢给模型,而是为了让用户真的能在屏幕上写一个数字,再把它变成模型能识别的数据。也就是说,在 STM32N6 上,模型前面还需要先搭一个画板。
这套画板实现的核心思路其实很直接:先在屏幕上划出一个 28×28 的离散网格,让它和模型输入一一对应。代码里左侧画板区域被固定成一个 480×480 的正方形,然后再均匀切成 28×28 个小格。这样做有个很实际的好处:用户在屏幕上每写一笔,最终落下来的不是一堆连续像素,而是直接映射到模型真正要吃的 28×28 输入空间里。换句话说,画板本身就已经在帮后面的推理做第一层数据规整了。
顺着这个思路,代码里最重要的一块就是这个 28×28 的二维数组。每个格子本质上都对应模型输入里的一个像素点,默认是未填充状态,用户触摸之后再把对应位置置成有效值。这样一来,屏幕上的书写动作就被收敛成了一张标准尺寸的二值图。这个设计思路很讨巧——没有绕很远,不是先在大分辨率画布上画完再复杂缩放,而是从一开始就把输入目标锁死在模型真正需要的尺寸上,路径非常短,也比较稳。
当然,只把一个触摸点映射成单个格子,写出来的数字会很细,很容易断。所以在具体处理触摸的时候,代码里没有只点亮当前命中的一个单元,而是会顺手把周围 3×3 范围内的格子一起填上。这个动作特别像给笔画加了一个最小号的画笔粗细。这样做之后,用户写出来的数字会更连贯,边缘也不会太锯齿,整体更接近模型训练时看到的那类手写数字分布。坦率地讲,这一步非常值——几乎没增加什么复杂度,却明显改善了输入质量。
void DigitBoard_HandleTouch(uint16_t x, uint16_t y, uint8_t touching)
{
DigitBoard_TargetTypeDef target;
if (touching == 0U)
{
s_last_touch_target = DIGIT_TARGET_NONE;
return;
}
target = DigitBoard_GetTouchTarget(x, y);
if (target == DIGIT_TARGET_BOARD)
{
DigitBoard_SetCellFromTouch(x, y);
s_last_touch_target = target;
return;
}
if ((target == DIGIT_TARGET_CLEAR_BUTTON) && (s_last_touch_target != DIGIT_TARGET_CLEAR_BUTTON))
{
DigitBoard_ClearState();
}
s_last_touch_target = target;
}再往下,画板能用起来,离不开触摸芯片这一层。这里项目里用的是 GT911。简单讲就是把触摸控制器采上来的原始坐标,通过 I2C 读出来,再映射成 LCD 上真正可用的屏幕坐标。这里面还有一个挺实际的处理:因为触摸面板的坐标方向和屏幕显示方向不一定完全一致,所以代码里专门做了横竖方向的判断和缩放映射,把原始点转换成 800×480 的显示坐标。如果这里没对齐,后面手指落点和屏幕显示位置一错位,整个画板体验会立刻变得很别扭。
在 UI 结构上,这个画板也没有做得很花哨。左边是 28×28 的书写区域,右边是结果显示区域,底下再给一个 Clear 按钮负责清空状态。这个布局其实很适合这种 Demo——它够直接,读者或者演示者一上手就知道要在哪写,写完看哪,错了点哪清掉。很多嵌入式项目做 Demo 的时候,功能本身不复杂,但交互路径绕,最后显得整个系统很笨。这里画板的结构比较克制,反而把重点都让给了识别本身。
所以回头看,画板实现这一步,本质上做了三件事。第一,把用户的连续手写动作约束成模型真正需要的 28×28 输入格式。第二,用 GT911 把触摸坐标稳定读出来,并且正确映射到显示空间。第三,用一个尽量简单但足够顺手的交互界面,把书写、清空和结果显示串成一个完整闭环。
4、神经网络处理
float probs[10];
printf("raw_out:");
for (int i = 0; i < 10; i++)
{
printf(" %d", (int)((int8_t)buffer_out[i]));
probs[i] = 0.00390625f * ((float)((int8_t)buffer_out[i]) + 128.0f);
}
printf("rn");
DigitBoard_ShowPredictions(probs);
}神经网络的部分,除了 CubeMX 生成的运算内容,主要就是后处理。因为这个模型的输出是 10 个 int8 类型的数值,我们需要对其进行归一化,将 int8 类型数值映射到 0~1 之间的概率值。
计算出 10 个数字概率后,再使用绘图函数将其绘制出来。
void MX_X_CUBE_AI_Process(void)
{
/* USER CODE BEGIN 6 */
LL_ATON_RT_RetValues_t ll_aton_rt_ret = LL_ATON_RT_DONE;
const LL_Buffer_InfoTypeDef * ibuffersInfos = NN_Interface_Default.input_buffers_info();
const LL_Buffer_InfoTypeDef * obuffersInfos = NN_Interface_Default.output_buffers_info();
buffer_in = (uint8_t *)LL_Buffer_addr_start(&ibuffersInfos[0]);
buffer_out = (uint8_t *)LL_Buffer_addr_start(&obuffersInfos[0]);
LL_ATON_RT_RuntimeInit();
memcpy(buffer_in,s_digit_board,28*28);
SCB_CleanDCache_by_Addr(buffer_in, 28 * 28);
// run 10 inferences
for (int inferenceNb = 0; inferenceNb<1; ++inferenceNb) {
/* ------------- */
/* - Inference - */
/* ------------- */
/* Pre-process and fill the input buffer */
//_pre_process(buffer_in);
/* Perform the inference */
LL_ATON_RT_Init_Network(&NN_Instance_Default); // Initialize passed network instance object
do {
/* Execute first/next step */
ll_aton_rt_ret = LL_ATON_RT_RunEpochBlock(&NN_Instance_Default);
/* Wait for next event */
if (ll_aton_rt_ret == LL_ATON_RT_WFE) {
LL_ATON_OSAL_WFE();
}
} while (ll_aton_rt_ret != LL_ATON_RT_DONE);
SCB_InvalidateDCache_by_Addr(buffer_out, 10);
/* Post-process the output buffer */
/* Invalidate the associated CPU cache region if requested */
{
{
float probs[10];
printf("raw_out:");
for (int i = 0; i < 10; i++)
{
printf(" %d", (int)((int8_t)buffer_out[i]));
probs[i] = 0.00390625f * ((float)((int8_t)buffer_out[i]) + 128.0f);
}
printf("rn");
DigitBoard_ShowPredictions(probs);
}
}
LL_ATON_RT_DeInit_Network(&NN_Instance_Default);
}
LL_ATON_RT_RuntimeDeInit();
/* USER CODE END 6 */
}