diff --git a/packages/cli/src/actions/db.ts b/packages/cli/src/actions/db.ts index 6ffa7003c..592f9888d 100644 --- a/packages/cli/src/actions/db.ts +++ b/packages/cli/src/actions/db.ts @@ -1,5 +1,5 @@ import { formatDocument, ZModelCodeGenerator } from '@zenstackhq/language'; -import { DataModel, Enum, type Model } from '@zenstackhq/language/ast'; +import { DataModel, Enum, isDataField, type DataField, type Model } from '@zenstackhq/language/ast'; import colors from 'colors'; import fs from 'node:fs'; import path from 'node:path'; @@ -14,7 +14,7 @@ import { } from './action-utils'; import { consolidateEnums, syncEnums, syncRelation, syncTable, type Relation } from './pull'; import { providers as pullProviders } from './pull/provider'; -import { getDatasource, getDbName, getRelationFieldsKey, getRelationFkName, isDatabaseManagedAttribute } from './pull/utils'; +import { getDatasource, getDbName, getRelationFieldsKey, getRelationFkName, getRelationName, isDatabaseManagedAttribute } from './pull/utils'; import type { DataSourceProviderType } from '@zenstackhq/schema'; import { CliError } from '../cli-error'; @@ -35,6 +35,25 @@ export type PullOptions = { indent: number; }; +function hasRelationFieldsArg(field: DataField) { + const relationAttr = field.attributes.find((a) => a.decl.ref?.name === '@relation'); + return !!relationAttr?.args.some((a) => a.name === 'fields'); +} + +function getReferencedModelName(field: DataField) { + return field.type.reference?.ref ? getDbName(field.type.reference.ref) : undefined; +} + +function matchesRelationNameFallback(field: DataField, relationName: string, candidate: DataField) { + const referencedModelName = getReferencedModelName(field); + return ( + !!referencedModelName && + getRelationName(candidate) === relationName && + hasRelationFieldsArg(candidate) === hasRelationFieldsArg(field) && + getReferencedModelName(candidate) === referencedModelName + ); +} + /** * CLI action for db related commands */ @@ -283,46 +302,52 @@ async function runPull(options: PullOptions) { } newDataModel.fields.forEach((f) => { - // Prioritized matching: exact db name > relation fields key > relation FK name > type reference + // Prioritized matching: exact db name > relation fields key > relation FK name > relation name > type reference let originalFields = originalDataModel.fields.filter((d) => getDbName(d) === getDbName(f)); - // If this is a back-reference relation field (has @relation but no `fields` arg), silently skip - const isRelationField = - f.$type === 'DataField' && !!(f as any).attributes?.some((a: any) => a?.decl?.ref?.name === '@relation'); - if (originalFields.length === 0 && isRelationField && !getRelationFieldsKey(f as any)) { - return; - } - if (originalFields.length === 0) { // Try matching by relation fields key (the `fields` attribute in @relation) // This matches relation fields by their FK field references - const newFieldsKey = getRelationFieldsKey(f as any); + const newFieldsKey = isDataField(f) ? getRelationFieldsKey(f) : undefined; if (newFieldsKey) { originalFields = originalDataModel.fields.filter( - (d) => getRelationFieldsKey(d as any) === newFieldsKey, + (d) => isDataField(d) && getRelationFieldsKey(d) === newFieldsKey, ); } } if (originalFields.length === 0) { // Try matching by relation FK name (the `map` attribute in @relation) - originalFields = originalDataModel.fields.filter( - (d) => - getRelationFkName(d as any) === getRelationFkName(f as any) && - !!getRelationFkName(d as any) && - !!getRelationFkName(f as any), - ); + const newFkName = isDataField(f) ? getRelationFkName(f) : undefined; + if (newFkName) { + originalFields = originalDataModel.fields.filter( + (d) => isDataField(d) && getRelationFkName(d) === newFkName, + ); + } + } + + if (originalFields.length === 0) { + // Try matching by relation name (the `name` arg in @relation) + // This is essential for back-reference fields that only have a relation name + const newRelName = isDataField(f) ? getRelationName(f) : undefined; + if (newRelName) { + originalFields = originalDataModel.fields.filter( + (d) => + isDataField(d) && + isDataField(f) && + matchesRelationNameFallback(f, newRelName, d), + ); + } } if (originalFields.length === 0) { // Try matching by type reference // We need this because for relations that don't have @relation, we can only check if the original exists by the field type. // Yes, in this case it can potentially result in multiple original fields, but we only want to ensure that at least one relation exists. - // In the future, we might implement some logic to detect how many of these types of relations we need and add/remove fields based on this. originalFields = originalDataModel.fields.filter( (d) => - f.$type === 'DataField' && - d.$type === 'DataField' && + isDataField(f) && + isDataField(d) && f.type.reference?.ref && d.type.reference?.ref && getDbName(f.type.reference.ref) === getDbName(d.type.reference.ref), @@ -332,7 +357,7 @@ async function runPull(options: PullOptions) { if (originalFields.length > 1) { // If this is a back-reference relation field (no `fields` attribute), // silently skip when there are multiple potential matches - const isBackReferenceField = !getRelationFieldsKey(f as any); + const isBackReferenceField = isDataField(f) && !getRelationFieldsKey(f); if (!isBackReferenceField) { console.warn( colors.yellow( @@ -499,31 +524,43 @@ async function runPull(options: PullOptions) { }); originalDataModel.fields .filter((f) => { - // Prioritized matching: exact db name > relation fields key > relation FK name > type reference + // Prioritized matching: exact db name > relation fields key > relation FK name > relation name > type reference const matchByDbName = newDataModel.fields.find((d) => getDbName(d) === getDbName(f)); if (matchByDbName) return false; // Try matching by relation fields key (the `fields` attribute in @relation) - const originalFieldsKey = getRelationFieldsKey(f as any); + const originalFieldsKey = isDataField(f) ? getRelationFieldsKey(f) : undefined; if (originalFieldsKey) { const matchByFieldsKey = newDataModel.fields.find( - (d) => getRelationFieldsKey(d as any) === originalFieldsKey, + (d) => isDataField(d) && getRelationFieldsKey(d) === originalFieldsKey, ); if (matchByFieldsKey) return false; } - const matchByFkName = newDataModel.fields.find( - (d) => - getRelationFkName(d as any) === getRelationFkName(f as any) && - !!getRelationFkName(d as any) && - !!getRelationFkName(f as any), - ); - if (matchByFkName) return false; + const originalFkName = isDataField(f) ? getRelationFkName(f) : undefined; + if (originalFkName) { + const matchByFkName = newDataModel.fields.find( + (d) => isDataField(d) && getRelationFkName(d) === originalFkName, + ); + if (matchByFkName) return false; + } + + // Try matching by relation name (for named back-reference fields) + const originalRelName = isDataField(f) ? getRelationName(f) : undefined; + if (originalRelName) { + const matchByRelName = newDataModel.fields.find( + (d) => + isDataField(d) && + isDataField(f) && + matchesRelationNameFallback(f, originalRelName, d), + ); + if (matchByRelName) return false; + } const matchByTypeRef = newDataModel.fields.find( (d) => - f.$type === 'DataField' && - d.$type === 'DataField' && + isDataField(f) && + isDataField(d) && f.type.reference?.ref && d.type.reference?.ref && getDbName(f.type.reference.ref) === getDbName(d.type.reference.ref), diff --git a/packages/cli/src/actions/pull/utils.ts b/packages/cli/src/actions/pull/utils.ts index 9ec056bc4..04e565e31 100644 --- a/packages/cli/src/actions/pull/utils.ts +++ b/packages/cli/src/actions/pull/utils.ts @@ -14,7 +14,7 @@ import { type StringLiteral, } from '@zenstackhq/language/ast'; import type { AstFactory, ExpressionBuilder } from '@zenstackhq/language/factory'; -import { getLiteralArray, getStringLiteral } from '@zenstackhq/language/utils'; +import { getAttributeArgLiteral, getLiteralArray, getStringLiteral } from '@zenstackhq/language/utils'; import type { DataSourceProviderType } from '@zenstackhq/schema'; import type { Reference } from 'langium'; import { CliError } from '../../cli-error'; @@ -122,6 +122,19 @@ export function getRelationFkName(decl: DataField): string | undefined { return schemaAttrValue?.value; } +/** + * Gets the relation name from the @relation attribute's `name` argument. + * e.g., @relation('myRelation', fields: [...], references: [...]) -> "myRelation" + * e.g., @relation(name: 'myRelation', fields: [...], references: [...]) -> "myRelation" + * e.g., @relation(fields: [...], references: [...]) -> undefined + * e.g., @relation('backRef') -> "backRef" + */ +export function getRelationName(decl: DataField): string | undefined { + const relationAttr = decl?.attributes?.find((a) => a.decl?.ref?.name === '@relation'); + if (!relationAttr) return undefined; + return getAttributeArgLiteral(relationAttr, 'name'); +} + /** * Gets the FK field names from the @relation attribute's `fields` argument. * Returns a sorted, comma-separated string of field names for comparison. diff --git a/packages/cli/test/db/pull.test.ts b/packages/cli/test/db/pull.test.ts index 2750a2228..811c20ccf 100644 --- a/packages/cli/test/db/pull.test.ts +++ b/packages/cli/test/db/pull.test.ts @@ -152,6 +152,83 @@ model Tag { expect(restoredSchema).toEqual(schema); }); + it('should restore opposite relation fields when multiple models have FKs to the same target', async () => { + const { workDir, schema } = await createProject( + `model Comment { + id Int @id @default(autoincrement()) + text String + commentCreatedBy User? @relation('Comment_createdByToUser', fields: [createdBy], references: [id]) + createdBy Int? + commentUpdatedBy User? @relation('Comment_updatedByToUser', fields: [updatedBy], references: [id]) + updatedBy Int? +} + +model Post { + id Int @id @default(autoincrement()) + title String + postCreatedBy User? @relation('Post_createdByToUser', fields: [createdBy], references: [id]) + createdBy Int? + postUpdatedBy User? @relation('Post_updatedByToUser', fields: [updatedBy], references: [id]) + updatedBy Int? +} + +model User { + id Int @id @default(autoincrement()) + email String @unique + commentCreatedBy Comment[] @relation('Comment_createdByToUser') + commentUpdatedBy Comment[] @relation('Comment_updatedByToUser') + postCreatedBy Post[] @relation('Post_createdByToUser') + postUpdatedBy Post[] @relation('Post_updatedByToUser') +}`, + ); + runCli('db push', workDir); + + const schemaFile = path.join(workDir, 'zenstack/schema.zmodel'); + + fs.writeFileSync(schemaFile, getDefaultPrelude()); + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(schema); + }); + + it('should preserve opposite relation fields when multiple models have FKs to the same target', async () => { + const { workDir, schema } = await createProject( + `model Comment { + id Int @id @default(autoincrement()) + text String + commentCreatedBy User? @relation('Comment_createdByToUser', fields: [createdBy], references: [id]) + createdBy Int? + commentUpdatedBy User? @relation('Comment_updatedByToUser', fields: [updatedBy], references: [id]) + updatedBy Int? +} + +model Post { + id Int @id @default(autoincrement()) + title String + postCreatedBy User? @relation('Post_createdByToUser', fields: [createdBy], references: [id]) + createdBy Int? + postUpdatedBy User? @relation('Post_updatedByToUser', fields: [updatedBy], references: [id]) + updatedBy Int? +} + +model User { + id Int @id @default(autoincrement()) + email String @unique + commentCreatedBy Comment[] @relation('Comment_createdByToUser') + commentUpdatedBy Comment[] @relation('Comment_updatedByToUser') + postCreatedBy Post[] @relation('Post_createdByToUser') + postUpdatedBy Post[] @relation('Post_updatedByToUser') +}`, + ); + runCli('db push', workDir); + + runCli('db pull --indent 4', workDir); + + const restoredSchema = getSchema(workDir); + expect(restoredSchema).toEqual(schema); + }); + it('should restore one-to-one relation when FK is the single-column primary key', async () => { const { workDir, schema } = await createProject( `model Profile {