import { cv } from "../../cv";
import { ImageCropOption, ImageSource, Model } from "../common/model";
import { convertImage } from "../common/processors";

/** Options for a gender/age prediction; supports an optional crop region via ImageCropOption. */
interface GenderAgePredictOption extends ImageCropOption {
}

export interface GenderAgePredictResult {
  /** Predicted gender: "M" (male) or "F" (female). */
  gender: "M" | "F";
  /** Predicted age in years. */
  age: number;
}

const MODEL_URL_CONFIG = {
  INSIGHT_GENDER_AGE_ONNX: `https://www.modelscope.cn/models/luyizhi/basic_cv/resolve/master/faceattr/insight_gender_age.onnx`,
};

export class GenderAge extends Model {
  /** Loads and caches the gender/age model. Defaults to the INSIGHT_GENDER_AGE_ONNX weights. */
  public static async load(type?: keyof typeof MODEL_URL_CONFIG) {
    return this.cacheModel(MODEL_URL_CONFIG[type ?? "INSIGHT_GENDER_AGE_ONNX"], { createModel: true }).then(r => r.model);
  }

  /** Predicts gender and age for the given image source. */
  public predict(image: ImageSource, option?: GenderAgePredictOption) {
    return Model.resolveImage(image, im => this.doPredict(im, option));
  }

  private async doPredict(image: cv.Mat, option?: GenderAgePredictOption): Promise<GenderAgePredictResult> {
    const input = this.input;
    const output = this.output;
    // Crop to the requested region (if any), then resize to the model's input width/height.
    if (option?.crop) image = image.crop(option.crop.sx, option.crop.sy, option.crop.sw, option.crop.sh);
    image = image.resize(input.shape[3], input.shape[2]);

    // Convert the BGR pixel buffer to an RGB float tensor in NCHW layout, without normalization.
    const nchwImage = convertImage(image.data, { sourceImageFormat: "bgr", targetColorFormat: "rgb", targetShapeFormat: "nchw", targetNormalize: { mean: [0], std: [1] } });

    // Run inference and pick the single output tensor.
    const result = await this.session.run({
      [input.name]: {
        shape: [1, 3, input.shape[2], input.shape[3]],
        data: nchwImage,
        type: "float32",
      },
    }).then(res => res[output.name]);

    // Output layout: [femaleScore, maleScore, age / 100].
    return {
      gender: result[0] > result[1] ? "F" : "M",
      age: Math.round(result[2] * 100),
    };
  }
}
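
/*
 * Usage sketch (assumptions: the Model base class wires up `session`, `input`, and
 * `output` when `createModel: true` is passed, and ImageSource accepts a file path;
 * adjust to the actual API of the surrounding library):
 *
 *   const model = await GenderAge.load();
 *   const result = await model.predict("face.jpg", { crop: { sx: 0, sy: 0, sw: 112, sh: 112 } });
 *   console.log(result.gender, result.age); // e.g. "F" 28
 */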