BiRefNet
地位: 目前开源界最高清、边缘最锐利的去背景模型(2024年爆火)。 特点: 它通过极其复杂的双边参考架构,彻底解决了以前抠图模型在处理极高分辨率图像时“边缘糊成一团”的问题。对复杂背景下的发丝和半透明物体处理极为惊艳,目前在 ComfyUI 节点中几乎取代了以前的所有老模型。
以下是示例
效果演示
拖动分割线查看 BiRefNet 的抠图效果:

BiRefNet 抠图结果

原始图像
遮罩生成 (Mask)
模型生成的中间权重遮罩:

生成遮罩 (Mask)

原始图像
代码实现
BiRefNet
import os
import time
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# 配置路径
MODEL_PATH = '/mnt/sda/models/BiRefNet'
INPUT_DIR = './images'
OUTPUT_DIR = './results'
def process_images():
'''
主处理函数:加载模型,遍历图片并执行推理
'''
# 确保输出目录存在
mask_dir = os.path.join(OUTPUT_DIR, 'mask')
rembg_dir = os.path.join(OUTPUT_DIR, 'rembg')
os.makedirs(mask_dir, exist_ok=True)
os.makedirs(rembg_dir, exist_ok=True)
# 1. 加载模型
print(f'正在从 {MODEL_PATH} 加载模型...')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cpu':
print('警告: 未检测到可用 GPU 或驱动版本过低,将使用 CPU 运行。推理速度会显著变慢。')
try:
model = AutoModelForImageSegmentation.from_pretrained(
MODEL_PATH,
trust_remote_code=True
)
if device == 'cpu':
model = model.float() # CPU 模式下强制使用 float32,避免 float16 兼容性问题
model.to(device)
model.eval()
except Exception as e:
print(f'模型加载失败: {e}')
return
# 2. 预处理配置
# BiRefNet 标准输入分辨率为 1024x1024
input_size = (1024, 1024)
transform_image = transforms.Compose([
transforms.Resize(input_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 3. 获取待处理图片列表
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
if not os.path.exists(INPUT_DIR):
print(f'目录 {INPUT_DIR} 不存在,请先创建并放入图片。')
return
image_list = [f for f in os.listdir(INPUT_DIR) if f.lower().endswith(image_extensions)]
if not image_list:
print(f'在 {INPUT_DIR} 目录下未找到支持的图片文件。')
return
print(f'共找到 {len(image_list)} 张图片,开始推理...')
# 4. 循环处理
for img_name in image_list:
img_path = os.path.join(INPUT_DIR, img_name)
print(f'\n--- 正在处理: {img_name} ---')
try:
# 读取并准备图片
print('1. 读取图片中...')
input_image = Image.open(img_path).convert('RGB')
origin_size = input_image.size # (W, H)
# 预处理转换
print(f'2. 预处理中 (目标尺寸: {input_size})...')
input_tensor = transform_image(input_image).unsqueeze(0).to(device)
# 自动匹配模型参数精度
input_tensor = input_tensor.to(next(model.parameters()).dtype)
# 模型推理
print(f'3. 模型推理中 (运行设备: {device})...')
start_time = time.time()
with torch.no_grad():
# BiRefNet 返回图像金字塔的预测结果,取最后一层(精细层)
preds = model(input_tensor)[-1].sigmoid().cpu().float()
end_time = time.time()
print(f'4. 推理完成 (耗时: {end_time - start_time:.2f}s),正在提取遮罩...')
# 提取遮罩 (Mask)
mask_tensor = preds[0].squeeze()
mask_pil = transforms.ToPILImage()(mask_tensor)
# 缩放至原图大小
mask_pil = mask_pil.resize(origin_size, Image.BILINEAR)
# 保存遮罩结果
save_name_base = os.path.splitext(img_name)[0]
mask_pil.save(os.path.join(mask_dir, f'{save_name_base}_mask.png'))
# 生成抠图 (Rembg)
print('5. 正在生成并保存抠图结果...')
rembg_image = input_image.copy().convert('RGBA')
rembg_image.putalpha(mask_pil)
rembg_image.save(os.path.join(rembg_dir, f'{save_name_base}_rembg.png'))
print(f'>>> 处理完成: {img_name}')
except Exception as e:
print(f'处理图片 {img_name} 时发生错误: {e}')
print(f'全部处理完成!\n结果保存在: \n- 遮罩: {mask_dir}\n- 抠图: {rembg_dir}')
if __name__ == '__main__':
process_images()