Skip to content

Generalize with variance #673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions chalk-solve/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,22 @@ impl<I: Interner> InferenceTable<I> {
.map(|p| p.assert_const_ref(interner).clone())
}

pub fn ty_root(&mut self, interner: &I, leaf: &Ty<I>) -> Option<Ty<I>> {
Some(
self.unify
.find(leaf.inference_var(interner)?)
.to_ty(interner),
)
}

pub fn lifetime_root(&mut self, interner: &I, leaf: &Lifetime<I>) -> Option<Lifetime<I>> {
Some(
self.unify
.find(leaf.inference_var(interner)?)
.to_lifetime(interner),
)
}

/// Finds the root inference var for the given variable.
///
/// The returned variable will be exactly equivalent to the given
Expand Down
215 changes: 157 additions & 58 deletions chalk-solve/src/infer/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,24 @@ impl<'t, I: Interner> Unifier<'t, I> {
T: ?Sized + Zip<I>,
{
Zip::zip_with(&mut self, variance, a, b)?;
Ok(RelationResult { goals: self.goals })
let interner = self.interner();
let mut goals = self.goals;
let table = self.table;
// Sometimes we'll produce a lifetime outlives goal which we later solve by unification
// Technically, these *will* get canonicalized to the same bound var and so that will end up
// as a goal like `^0.0 <: ^0.0`, which is trivially true. But, we remove those *here*, which
// might help caching.
goals.retain(|g| match g.goal.data(interner) {
GoalData::SubtypeGoal(SubtypeGoal { a, b }) => {
let n_a = table.ty_root(interner, a);
let n_b = table.ty_root(interner, b);
let a = n_a.as_ref().unwrap_or(a);
let b = n_b.as_ref().unwrap_or(b);
a != b
}
_ => true,
});
Ok(RelationResult { goals })
}

/// Relate `a`, `b` with the variance such that if `variance = Covariant`, `a` is
Expand Down Expand Up @@ -473,85 +490,128 @@ impl<'t, I: Interner> Unifier<'t, I> {
}

#[instrument(level = "debug", skip(self))]
fn generalize_ty(&mut self, ty: &Ty<I>, universe_index: UniverseIndex) -> Ty<I> {
fn generalize_ty(
&mut self,
ty: &Ty<I>,
universe_index: UniverseIndex,
variance: Variance,
) -> Ty<I> {
let interner = self.interner;
match ty.kind(interner) {
TyKind::Adt(id, substitution) => TyKind::Adt(
*id,
self.generalize_substitution(substitution, universe_index),
)
.intern(interner),
TyKind::Adt(id, substitution) => {
let variances = if matches!(variance, Variance::Invariant) {
None
} else {
Some(self.unification_database().adt_variance(*id))
};
let get_variance = |i| {
variances
.as_ref()
.map(|v| v.as_slice(interner)[i])
.unwrap_or(Variance::Invariant)
};
TyKind::Adt(
*id,
self.generalize_substitution(substitution, universe_index, get_variance),
)
.intern(interner)
}
TyKind::AssociatedType(id, substitution) => TyKind::AssociatedType(
*id,
self.generalize_substitution(substitution, universe_index),
self.generalize_substitution(substitution, universe_index, |_| variance),
)
.intern(interner),
TyKind::Scalar(scalar) => TyKind::Scalar(*scalar).intern(interner),
TyKind::Str => TyKind::Str.intern(interner),
TyKind::Tuple(arity, substitution) => TyKind::Tuple(
*arity,
self.generalize_substitution(substitution, universe_index),
self.generalize_substitution(substitution, universe_index, |_| variance),
)
.intern(interner),
TyKind::OpaqueType(id, substitution) => TyKind::OpaqueType(
*id,
self.generalize_substitution(substitution, universe_index),
self.generalize_substitution(substitution, universe_index, |_| variance),
)
.intern(interner),
TyKind::Slice(ty) => {
TyKind::Slice(self.generalize_ty(ty, universe_index)).intern(interner)
TyKind::Slice(self.generalize_ty(ty, universe_index, variance)).intern(interner)
}
TyKind::FnDef(id, substitution) => {
let variances = if matches!(variance, Variance::Invariant) {
None
} else {
Some(self.unification_database().fn_def_variance(*id))
};
let get_variance = |i| {
variances
.as_ref()
.map(|v| v.as_slice(interner)[i])
.unwrap_or(Variance::Invariant)
};
TyKind::FnDef(
*id,
self.generalize_substitution(substitution, universe_index, get_variance),
)
.intern(interner)
}
TyKind::Ref(mutability, lifetime, ty) => {
let lifetime_variance = variance.xform(Variance::Contravariant);
let ty_variance = match mutability {
Mutability::Not => Variance::Covariant,
Mutability::Mut => Variance::Invariant,
};
TyKind::Ref(
*mutability,
self.generalize_lifetime(lifetime, universe_index, lifetime_variance),
self.generalize_ty(ty, universe_index, ty_variance),
)
.intern(interner)
}
TyKind::FnDef(id, substitution) => TyKind::FnDef(
*id,
self.generalize_substitution(substitution, universe_index),
)
.intern(interner),
TyKind::Ref(mutability, lifetime, ty) => TyKind::Ref(
*mutability,
self.generalize_lifetime(lifetime, universe_index),
self.generalize_ty(ty, universe_index),
)
.intern(interner),
TyKind::Raw(mutability, ty) => {
TyKind::Raw(*mutability, self.generalize_ty(ty, universe_index)).intern(interner)
let ty_variance = match mutability {
Mutability::Not => Variance::Covariant,
Mutability::Mut => Variance::Invariant,
};
TyKind::Raw(
*mutability,
self.generalize_ty(ty, universe_index, ty_variance),
)
.intern(interner)
}
TyKind::Never => TyKind::Never.intern(interner),
TyKind::Array(ty, const_) => TyKind::Array(
self.generalize_ty(ty, universe_index),
self.generalize_ty(ty, universe_index, variance),
self.generalize_const(const_, universe_index),
)
.intern(interner),
TyKind::Closure(id, substitution) => TyKind::Closure(
*id,
self.generalize_substitution(substitution, universe_index),
self.generalize_substitution(substitution, universe_index, |_| variance),
)
.intern(interner),
TyKind::Generator(id, substitution) => TyKind::Generator(
*id,
self.generalize_substitution(substitution, universe_index),
self.generalize_substitution(substitution, universe_index, |_| variance),
)
.intern(interner),
TyKind::GeneratorWitness(id, substitution) => TyKind::GeneratorWitness(
*id,
self.generalize_substitution(substitution, universe_index),
self.generalize_substitution(substitution, universe_index, |_| variance),
)
.intern(interner),
TyKind::Foreign(id) => TyKind::Foreign(*id).intern(interner),
TyKind::Error => TyKind::Error.intern(interner),
TyKind::Dyn(dyn_ty) => {
let DynTy {
bounds,
lifetime: _,
} = dyn_ty;
let lifetime_var = self.table.new_variable(universe_index);
let lifetime = lifetime_var.to_lifetime(interner);
let DynTy { bounds, lifetime } = dyn_ty;
let lifetime = self.generalize_lifetime(
lifetime,
universe_index,
variance.xform(Variance::Contravariant),
);

let bounds = bounds.map_ref(|value| {
//let universe_index = universe_index.next();
let iter = value.iter(interner).map(|sub_var| {
sub_var.map_ref(|clause| {
//let universe_index = universe_index.next();
// let universe_index = self.table.new_universe();
match clause {
WhereClause::Implemented(trait_ref) => {
let TraitRef {
Expand All @@ -561,6 +621,7 @@ impl<'t, I: Interner> Unifier<'t, I> {
let substitution = self.generalize_substitution_skip_self(
substitution,
universe_index,
|_| Some(variance),
);
WhereClause::Implemented(TraitRef {
substitution,
Expand All @@ -578,6 +639,7 @@ impl<'t, I: Interner> Unifier<'t, I> {
let substitution = self.generalize_substitution(
substitution,
universe_index,
|_| variance,
);
AliasTy::Opaque(OpaqueTy {
substitution,
Expand All @@ -598,6 +660,7 @@ impl<'t, I: Interner> Unifier<'t, I> {
let substitution = self.generalize_substitution(
substitution,
universe_index,
|_| variance,
);
AliasTy::Projection(ProjectionTy {
substitution,
Expand Down Expand Up @@ -637,8 +700,25 @@ impl<'t, I: Interner> Unifier<'t, I> {
ref substitution,
} = *fn_ptr;

let substitution =
FnSubst(self.generalize_substitution(&substitution.0, universe_index));
let len = substitution.0.len(interner);
let vars = substitution.0.iter(interner).enumerate().map(|(i, var)| {
if i < len - 1 {
self.generalize_generic_var(
var,
universe_index,
variance.xform(Variance::Contravariant),
)
} else {
self.generalize_generic_var(
substitution.0.as_slice(interner).last().unwrap(),
universe_index,
variance,
)
}
});

let substitution = FnSubst(Substitution::from_iter(interner, vars));

TyKind::Function(FnPointer {
num_binders,
sig,
Expand All @@ -660,7 +740,9 @@ impl<'t, I: Interner> Unifier<'t, I> {
if matches!(kind, TyVariableKind::Integer | TyVariableKind::Float) {
ty.clone()
} else if let Some(ty) = self.table.normalize_ty_shallow(interner, ty) {
self.generalize_ty(&ty, universe_index)
self.generalize_ty(&ty, universe_index, variance)
} else if matches!(variance, Variance::Invariant) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rustc code here also checks that this variable can name universe_index here, although that's maybe only for error messages.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. So should I change this here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that we can keep the simpler code for now.

ty.clone()
} else {
let ena_var = self.table.new_variable(universe_index);
ena_var.to_ty(interner)
Expand All @@ -674,15 +756,20 @@ impl<'t, I: Interner> Unifier<'t, I> {
&mut self,
lifetime: &Lifetime<I>,
universe_index: UniverseIndex,
variance: Variance,
) -> Lifetime<I> {
let interner = self.interner;
match lifetime.data(&interner) {
LifetimeData::BoundVar(_) => {
return lifetime.clone();
}
_ => {
let ena_var = self.table.new_variable(universe_index);
ena_var.to_lifetime(interner)
if matches!(variance, Variance::Invariant) {
lifetime.clone()
} else {
let ena_var = self.table.new_variable(universe_index);
ena_var.to_lifetime(interner)
}
}
}
}
Expand All @@ -706,13 +793,16 @@ impl<'t, I: Interner> Unifier<'t, I> {
&mut self,
sub_var: &GenericArg<I>,
universe_index: UniverseIndex,
variance: Variance,
) -> GenericArg<I> {
let interner = self.interner;
(match sub_var.data(interner) {
GenericArgData::Ty(ty) => GenericArgData::Ty(self.generalize_ty(ty, universe_index)),
GenericArgData::Lifetime(lifetime) => {
GenericArgData::Lifetime(self.generalize_lifetime(lifetime, universe_index))
GenericArgData::Ty(ty) => {
GenericArgData::Ty(self.generalize_ty(ty, universe_index, variance))
}
GenericArgData::Lifetime(lifetime) => GenericArgData::Lifetime(
self.generalize_lifetime(lifetime, universe_index, variance),
),
GenericArgData::Const(const_value) => {
GenericArgData::Const(self.generalize_const(const_value, universe_index))
}
Expand All @@ -721,32 +811,37 @@ impl<'t, I: Interner> Unifier<'t, I> {
}

/// Generalizes all but the first
#[instrument(level = "debug", skip(self))]
fn generalize_substitution_skip_self(
#[instrument(level = "debug", skip(self, get_variance))]
fn generalize_substitution_skip_self<F: Fn(usize) -> Option<Variance>>(
&mut self,
substitution: &Substitution<I>,
universe_index: UniverseIndex,
get_variance: F,
) -> Substitution<I> {
let interner = self.interner;
let vars = substitution.iter(interner).take(1).cloned().chain(
substitution
.iter(interner)
.skip(1)
.map(|sub_var| self.generalize_generic_var(sub_var, universe_index)),
);
let vars = substitution.iter(interner).enumerate().map(|(i, sub_var)| {
if i == 0 {
sub_var.clone()
} else {
let variance = get_variance(i).unwrap_or(Variance::Invariant);
self.generalize_generic_var(sub_var, universe_index, variance)
}
});
Substitution::from_iter(interner, vars)
}

#[instrument(level = "debug", skip(self))]
fn generalize_substitution(
#[instrument(level = "debug", skip(self, get_variance))]
fn generalize_substitution<F: Fn(usize) -> Variance>(
&mut self,
substitution: &Substitution<I>,
universe_index: UniverseIndex,
get_variance: F,
) -> Substitution<I> {
let interner = self.interner;
let vars = substitution
.iter(interner)
.map(|sub_var| self.generalize_generic_var(sub_var, universe_index));
let vars = substitution.iter(interner).enumerate().map(|(i, sub_var)| {
let variance = get_variance(i);
self.generalize_generic_var(sub_var, universe_index, variance)
});

Substitution::from_iter(interner, vars)
}
Expand Down Expand Up @@ -822,7 +917,7 @@ impl<'t, I: Interner> Unifier<'t, I> {
// this, we create two new vars `'0` and `1`. Then we relate `var` with
// `&'0 1` and `&'0 1` with `&'x SomeType`. The second relation will
// recurse, and we'll end up relating `'0` with `'x` and `1` with `SomeType`.
let generalized_val = self.generalize_ty(&ty1, universe_index);
let generalized_val = self.generalize_ty(&ty1, universe_index, variance);

debug!("var {:?} generalized to {:?}", var, generalized_val);

Expand Down Expand Up @@ -1259,6 +1354,10 @@ where
// become the value of).
InferenceValue::Unbound(ui) => {
if self.unifier.table.unify.unioned(var, self.var) {
debug!(
"OccursCheck aborting because {:?} unioned with {:?}",
var, self.var,
);
return Err(NoSolution);
}

Expand Down
Loading