fix: refactor two commands in saving grade to prevent race conditions

This commit is contained in:
2024-01-20 13:13:59 +01:00
parent b897c775a9
commit 6b63fd498a
14 changed files with 140 additions and 97 deletions

View File

@@ -25,7 +25,7 @@ export class IsOrderByConstraint implements ValidatorConstraintInterface {
return true
}
defaultMessage(args: ValidationArguments) {
defaultMessage(_args: ValidationArguments) {
return 'Invalid format. Expected format is "key-direction". Direction must be "asc" or "desc".'
}
}

View File

@@ -10,7 +10,10 @@ import { UsersRepository } from '../users/infrastructure/users.repository'
@Injectable()
export class AuthService {
constructor(private usersRepository: UsersRepository, private prisma: PrismaService) {}
constructor(
private usersRepository: UsersRepository,
private prisma: PrismaService
) {}
async createJwtTokensPair(userId: string, rememberMe?: boolean) {
const accessSecretKey = process.env.ACCESS_JWT_SECRET_KEY

View File

@@ -5,7 +5,10 @@ import * as jwt from 'jsonwebtoken'
import { AuthRepository } from '../infrastructure/auth.repository'
export class RefreshTokenCommand {
constructor(public readonly userId: string, public readonly shortAccessToken: boolean) {}
constructor(
public readonly userId: string,
public readonly shortAccessToken: boolean
) {}
}
@CommandHandler(RefreshTokenCommand)

View File

@@ -5,7 +5,10 @@ import { UsersRepository } from '../../users/infrastructure/users.repository'
import { UsersService } from '../../users/services/users.service'
export class ResetPasswordCommand {
constructor(public readonly resetPasswordToken: string, public readonly newPassword: string) {}
constructor(
public readonly resetPasswordToken: string,
public readonly newPassword: string
) {}
}
@CommandHandler(ResetPasswordCommand)

View File

@@ -1,4 +1,5 @@
import { Injectable, InternalServerErrorException, Logger } from '@nestjs/common'
import { Injectable, InternalServerErrorException, Logger, NotFoundException } from '@nestjs/common'
import { pick } from 'remeda'
import {
createPrismaOrderBy,
@@ -224,4 +225,73 @@ export class CardsRepository {
throw new InternalServerErrorException(e?.message)
}
}
private async getSmartRandomCard(cards: Array<CardWithGrade>): Promise<CardWithGrade> {
const selectionPool: Array<CardWithGrade> = []
cards.forEach(card => {
// Calculate the average grade for the card
const averageGrade =
card.grades.length === 0
? 0
: card.grades.reduce((acc, grade) => acc + grade.grade, 0) / card.grades.length
// Calculate weight for the card, higher weight for lower grade card
const weight = 6 - averageGrade
// Add the card to the selection pool `weight` times
for (let i = 0; i < weight; i++) {
selectionPool.push(card)
}
})
return selectionPool[Math.floor(Math.random() * selectionPool.length)]
}
private async getNotDuplicateRandomCard(
cards: Array<CardWithGrade>,
previousCardId: string
): Promise<CardWithGrade> {
const randomCard = await this.getSmartRandomCard(cards)
if (!randomCard) {
this.logger.error(`No cards found in deck}`, {
previousCardId,
randomCard,
cards,
})
throw new NotFoundException(`No cards found in deck`)
}
if (randomCard.id === previousCardId && cards.length !== 1) {
return this.getNotDuplicateRandomCard(cards, previousCardId)
}
return randomCard
}
async getRandomCardInDeck(deckId: string, userId: string, previousCardId: string) {
const cards = await this.findCardsByDeckIdWithGrade(userId, deckId)
if (!cards.length) {
throw new NotFoundException(`No cards found in deck with id ${deckId}`)
}
const smartRandomCard = await this.getNotDuplicateRandomCard(cards, previousCardId)
return {
...pick(smartRandomCard, [
'id',
'question',
'answer',
'deckId',
'questionImg',
'answerImg',
'questionVideo',
'answerVideo',
'created',
'updated',
'shots',
]),
grade: smartRandomCard.grades[0]?.grade || 0,
}
}
}

View File

@@ -24,7 +24,10 @@ export class ResultNotification<T = null> {
}
export class NotificationExtension {
constructor(public message: string, public key: string | null) {}
constructor(
public message: string,
public key: string | null
) {}
}
export class DomainResultNotification<TData = null> extends ResultNotification<TData> {

View File

@@ -6,7 +6,10 @@ import { validationErrorsMapper, ValidationPipeErrorType } from '../../../settin
import { DomainResultNotification, ResultNotification } from './notification'
export class DomainError extends Error {
constructor(message: string, public resultNotification: ResultNotification) {
constructor(
message: string,
public resultNotification: ResultNotification
) {
super(message)
}
}

View File

@@ -192,13 +192,9 @@ export class DecksController {
description: 'Save the grade of a card',
summary: 'Save the grade of a card',
})
async saveGrade(@Param('id') deckId: string, @Req() req, @Body() body: SaveGradeDto) {
const saved = await this.commandBus.execute(
new SaveGradeCommand(req.user.id, { cardId: body.cardId, grade: body.grade })
)
async saveGrade(@Req() req, @Body() body: SaveGradeDto) {
return await this.commandBus.execute(
new GetRandomCardInDeckCommand(req.user.id, saved.deckId, saved.id)
new SaveGradeCommand(req.user.id, { cardId: body.cardId, grade: body.grade })
)
}
}

View File

@@ -6,7 +6,10 @@ import { Deck } from '../entities/deck.entity'
import { DecksRepository } from '../infrastructure/decks.repository'
export class CreateDeckCommand {
constructor(public readonly deck: CreateDeckDto, public readonly cover: Express.Multer.File) {}
constructor(
public readonly deck: CreateDeckDto,
public readonly cover: Express.Multer.File
) {}
}
@CommandHandler(CreateDeckCommand)

View File

@@ -4,7 +4,10 @@ import { CommandHandler, ICommandHandler } from '@nestjs/cqrs'
import { DecksRepository } from '../infrastructure/decks.repository'
export class DeleteDeckByIdCommand {
constructor(public readonly id: string, public readonly userId: string) {}
constructor(
public readonly id: string,
public readonly userId: string
) {}
}
@CommandHandler(DeleteDeckByIdCommand)

View File

@@ -1,7 +1,5 @@
import { ForbiddenException, Logger, NotFoundException } from '@nestjs/common'
import { CommandHandler, ICommandHandler } from '@nestjs/cqrs'
import { Prisma } from '@prisma/client'
import { pick } from 'remeda'
import { CardsRepository } from '../../cards/infrastructure/cards.repository'
import { DecksRepository } from '../infrastructure/decks.repository'
@@ -14,58 +12,15 @@ export class GetRandomCardInDeckCommand {
) {}
}
type CardWithGrade = Prisma.cardGetPayload<{ include: { grades: true } }>
@CommandHandler(GetRandomCardInDeckCommand)
export class GetRandomCardInDeckHandler implements ICommandHandler<GetRandomCardInDeckCommand> {
logger = new Logger(GetRandomCardInDeckHandler.name)
constructor(
private readonly cardsRepository: CardsRepository,
private readonly decksRepository: DecksRepository
) {}
private async getSmartRandomCard(cards: Array<CardWithGrade>): Promise<CardWithGrade> {
const selectionPool: Array<CardWithGrade> = []
cards.forEach(card => {
// Calculate the average grade for the card
const averageGrade =
card.grades.length === 0
? 0
: card.grades.reduce((acc, grade) => acc + grade.grade, 0) / card.grades.length
// Calculate weight for the card, higher weight for lower grade card
const weight = 6 - averageGrade
// Add the card to the selection pool `weight` times
for (let i = 0; i < weight; i++) {
selectionPool.push(card)
}
})
return selectionPool[Math.floor(Math.random() * selectionPool.length)]
}
private async getNotDuplicateRandomCard(
cards: Array<CardWithGrade>,
previousCardId: string
): Promise<CardWithGrade> {
const randomCard = await this.getSmartRandomCard(cards)
if (!randomCard) {
this.logger.error(`No cards found in deck}`, {
previousCardId,
randomCard,
cards,
})
throw new NotFoundException(`No cards found in deck`)
}
if (randomCard.id === previousCardId && cards.length !== 1) {
return this.getNotDuplicateRandomCard(cards, previousCardId)
}
return randomCard
}
async execute(command: GetRandomCardInDeckCommand) {
const deck = await this.decksRepository.findDeckById(command.deckId)
@@ -74,32 +29,11 @@ export class GetRandomCardInDeckHandler implements ICommandHandler<GetRandomCard
if (deck.userId !== command.userId && deck.isPrivate) {
throw new ForbiddenException(`You can't get a private deck that you don't own`)
}
const cards = await this.cardsRepository.findCardsByDeckIdWithGrade(
return await this.cardsRepository.getRandomCardInDeck(
command.deckId,
command.userId,
command.deckId
command.previousCardId
)
if (!cards.length) {
throw new NotFoundException(`No cards found in deck with id ${command.deckId}`)
}
const smartRandomCard = await this.getNotDuplicateRandomCard(cards, command.previousCardId)
return {
...pick(smartRandomCard, [
'id',
'question',
'answer',
'deckId',
'questionImg',
'answerImg',
'questionVideo',
'answerVideo',
'created',
'updated',
'shots',
]),
grade: smartRandomCard.grades[0]?.grade || 0,
}
}
}

View File

@@ -1,6 +1,12 @@
import { ForbiddenException, NotFoundException } from '@nestjs/common'
import {
ForbiddenException,
InternalServerErrorException,
Logger,
NotFoundException,
} from '@nestjs/common'
import { CommandHandler, ICommandHandler } from '@nestjs/cqrs'
import { CardsRepository } from '../../cards/infrastructure/cards.repository'
import { DecksRepository } from '../infrastructure/decks.repository'
import { GradesRepository } from '../infrastructure/grades.repository'
@@ -16,9 +22,12 @@ export class SaveGradeCommand {
@CommandHandler(SaveGradeCommand)
export class SaveGradeHandler implements ICommandHandler<SaveGradeCommand> {
private readonly logger = new Logger(SaveGradeHandler.name)
constructor(
private readonly decksRepository: DecksRepository,
private readonly gradesRepository: GradesRepository
private readonly gradesRepository: GradesRepository,
private readonly cardsRepository: CardsRepository
) {}
async execute(command: SaveGradeCommand) {
@@ -28,14 +37,21 @@ export class SaveGradeHandler implements ICommandHandler<SaveGradeCommand> {
throw new NotFoundException(`Deck containing card with id ${command.args.cardId} not found`)
if (deck.userId !== command.userId && deck.isPrivate) {
throw new ForbiddenException(`You can't save cards to a private deck that you don't own`)
throw new ForbiddenException(`You can't save cards to a private deck that you don't own`)
}
return await this.gradesRepository.createGrade({
userId: command.userId,
grade: command.args.grade,
cardId: command.args.cardId,
deckId: deck.id,
})
try {
await this.gradesRepository.createGrade({
userId: command.userId,
grade: command.args.grade,
cardId: command.args.cardId,
deckId: deck.id,
})
} catch (e) {
this.logger.error(e)
throw new InternalServerErrorException(e?.message)
}
return this.cardsRepository.getRandomCardInDeck(deck.id, command.userId, command.args.cardId)
}
}

View File

@@ -21,7 +21,10 @@ import { UsersService } from '../services/users.service'
@ApiTags('Admin')
@Controller('users')
export class UsersController {
constructor(private usersService: UsersService, private commandBus: CommandBus) {}
constructor(
private usersService: UsersService,
private commandBus: CommandBus
) {}
@Get()
async findAll(@Query() query) {

View File

@@ -38,7 +38,10 @@ class AuthSettings {
}
export class AppSettings {
constructor(public env: EnvironmentSettings, public auth: AuthSettings) {}
constructor(
public env: EnvironmentSettings,
public auth: AuthSettings
) {}
}
const env = new EnvironmentSettings((process.env.NODE_ENV || 'DEVELOPMENT') as EnvironmentsTypes)
const auth = new AuthSettings(process.env)