import { cv } from "../../cv";
import { ImageCropOption, ImageSource, Model } from "../common/model";
import { convertImage } from "../common/processors";

/** Options accepted by `GenderAge.predict`; currently only the optional crop region. */
export interface GenderAgePredictOption extends ImageCropOption { }

/** Result of a gender/age prediction. */
export interface GenderAgePredictResult {
    /** Predicted gender label. */
    gender: "M" | "F";
    /** Predicted age: model output index 2 scaled by 100 and rounded to an integer. */
    age: number;
}

/** Download URLs of the supported gender/age models, keyed by model type. */
const MODEL_URL_CONFIG = {
    INSIGHT_GENDER_AGE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.onnx`,
};

/**
 * Face gender/age estimator backed by an ONNX model listed in `MODEL_URL_CONFIG`.
 */
export class GenderAge extends Model {

    /**
     * Load (through `Model.cacheModel`) the model of the given type.
     * @param type key of `MODEL_URL_CONFIG`; defaults to `"INSIGHT_GENDER_AGE_ONNX"`.
     * @returns the `model` field of the value resolved by `cacheModel(url, { createModel: true })`.
     */
    public static async load(type?: keyof typeof MODEL_URL_CONFIG) {
        return this.cacheModel(MODEL_URL_CONFIG[type ?? "INSIGHT_GENDER_AGE_ONNX"], { createModel: true }).then(r => r.model);
    }

    /**
     * Predict gender and age for an image (expected to contain a face).
     * @param image image to analyse; converted to a `cv.Mat` via `Model.resolveImage`.
     * @param option optional crop region applied before resizing to the model input size.
     */
    public predict(image: ImageSource, option?: GenderAgePredictOption) {
        return Model.resolveImage(image, im => this.doPredict(im, option));
    }

    /**
     * Run the model on one BGR image and decode its three-value output.
     * @param image source image (converted from "bgr" to "rgb" before inference).
     * @param option optional crop region.
     * @returns predicted gender and age.
     */
    private async doPredict(image: cv.Mat, option?: GenderAgePredictOption): Promise<GenderAgePredictResult> {
        const input = this.input;
        const output = this.output;
        // The input tensor is fed as NCHW ([1, 3, H, W]), so shape[2] is the height and shape[3] the width.
        const height = input.shape[2];
        const width = input.shape[3];

        // Work on a local instead of reassigning the parameter.
        let mat = image;
        if (option?.crop) mat = mat.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
        mat = mat.resize(width, height);

        const nchwImage = convertImage(mat.data, {
            sourceImageFormat: "bgr",
            targetColorFormat: "rgb",
            targetShapeFormat: "nchw",
            // mean 0 / std 1 — presumably leaves pixel values unscaled; depends on convertImage.
            targetNormalize: { mean: [0], std: [1] },
        });

        const outputs = await this.session.run({
            [input.name]: {
                shape: [1, 3, height, width],
                data: nchwImage,
                type: "float32",
            },
        });
        const scores = outputs[output.name];

        return {
            // Indices 0 and 1 are compared: index 0 winning maps to "F", otherwise "M".
            gender: scores[0] > scores[1] ? "F" : "M",
            // Index 2 scaled by 100. Round (rather than truncate via parseInt on a
            // stringified number) to match insightface's reference post-processing.
            age: Math.round(scores[2] * 100),
        };
    }
}