跳到主要内容

BiRefNet

地位: 目前开源界最高清、边缘最锐利的去背景模型(2024年爆火)。 特点: 它通过极其复杂的双边参考架构,彻底解决了以前抠图模型在处理极高分辨率图像时“边缘糊成一团”的问题。对复杂背景下的发丝和半透明物体处理极为惊艳,目前在 ComfyUI 节点中几乎取代了以前的所有老模型。

以下是示例

效果演示

拖动分割线查看 BiRefNet 的抠图效果:

After
BiRefNet 抠图结果
Before
原始图像

遮罩生成 (Mask)

模型生成的中间权重遮罩:

After
生成遮罩 (Mask)
Before
原始图像

代码实现

BiRefNet
import os
import time
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# 配置路径
MODEL_PATH = '/mnt/sda/models/BiRefNet'
INPUT_DIR = './images'
OUTPUT_DIR = './results'

def process_images():
'''
主处理函数:加载模型,遍历图片并执行推理
'''
# 确保输出目录存在
mask_dir = os.path.join(OUTPUT_DIR, 'mask')
rembg_dir = os.path.join(OUTPUT_DIR, 'rembg')
os.makedirs(mask_dir, exist_ok=True)
os.makedirs(rembg_dir, exist_ok=True)

# 1. 加载模型
print(f'正在从 {MODEL_PATH} 加载模型...')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cpu':
print('警告: 未检测到可用 GPU 或驱动版本过低,将使用 CPU 运行。推理速度会显著变慢。')

try:
model = AutoModelForImageSegmentation.from_pretrained(
MODEL_PATH,
trust_remote_code=True
)
if device == 'cpu':
model = model.float() # CPU 模式下强制使用 float32,避免 float16 兼容性问题
model.to(device)
model.eval()
except Exception as e:
print(f'模型加载失败: {e}')
return

# 2. 预处理配置
# BiRefNet 标准输入分辨率为 1024x1024
input_size = (1024, 1024)
transform_image = transforms.Compose([
transforms.Resize(input_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 3. 获取待处理图片列表
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
if not os.path.exists(INPUT_DIR):
print(f'目录 {INPUT_DIR} 不存在,请先创建并放入图片。')
return

image_list = [f for f in os.listdir(INPUT_DIR) if f.lower().endswith(image_extensions)]

if not image_list:
print(f'在 {INPUT_DIR} 目录下未找到支持的图片文件。')
return

print(f'共找到 {len(image_list)} 张图片,开始推理...')

# 4. 循环处理
for img_name in image_list:
img_path = os.path.join(INPUT_DIR, img_name)
print(f'\n--- 正在处理: {img_name} ---')

try:
# 读取并准备图片
print('1. 读取图片中...')
input_image = Image.open(img_path).convert('RGB')
origin_size = input_image.size # (W, H)

# 预处理转换
print(f'2. 预处理中 (目标尺寸: {input_size})...')
input_tensor = transform_image(input_image).unsqueeze(0).to(device)
# 自动匹配模型参数精度
input_tensor = input_tensor.to(next(model.parameters()).dtype)

# 模型推理
print(f'3. 模型推理中 (运行设备: {device})...')
start_time = time.time()
with torch.no_grad():
# BiRefNet 返回图像金字塔的预测结果,取最后一层(精细层)
preds = model(input_tensor)[-1].sigmoid().cpu().float()
end_time = time.time()
print(f'4. 推理完成 (耗时: {end_time - start_time:.2f}s),正在提取遮罩...')

# 提取遮罩 (Mask)
mask_tensor = preds[0].squeeze()
mask_pil = transforms.ToPILImage()(mask_tensor)
# 缩放至原图大小
mask_pil = mask_pil.resize(origin_size, Image.BILINEAR)

# 保存遮罩结果
save_name_base = os.path.splitext(img_name)[0]
mask_pil.save(os.path.join(mask_dir, f'{save_name_base}_mask.png'))

# 生成抠图 (Rembg)
print('5. 正在生成并保存抠图结果...')
rembg_image = input_image.copy().convert('RGBA')
rembg_image.putalpha(mask_pil)
rembg_image.save(os.path.join(rembg_dir, f'{save_name_base}_rembg.png'))

print(f'>>> 处理完成: {img_name}')

except Exception as e:
print(f'处理图片 {img_name} 时发生错误: {e}')

print(f'全部处理完成!\n结果保存在: \n- 遮罩: {mask_dir}\n- 抠图: {rembg_dir}')

if __name__ == '__main__':
process_images()