@dataclass
class XYZ:
    """In-memory representation of an XYZ / extended-XYZ structure file.

    Holds the per-atom table together with the file's comment line and any
    ``key=value`` parameters parsed from (or to be written to) that line.
    """

    # Per-atom table. `write` reads its 'symbol' and 'coords' columns.
    atoms: polars.DataFrame
    # Raw text of the file's second (comment) line, if any.
    comment: t.Optional[str] = None
    # Extended-XYZ key/value parameters (e.g. 'Lattice', 'pbc').
    params: t.Dict[str, str] = field(default_factory=dict)

    @staticmethod
    def from_atoms(atoms: HasAtoms) -> XYZ:
        """Build an `XYZ` from an atoms object.

        When `atoms` also carries a cell (`HasAtomCell`), the 'Lattice' and
        'pbc' parameters are filled in from it.
        """
        params: t.Dict[str, str] = {}
        if isinstance(atoms, HasAtomCell):
            # NOTE(review): `cell_matrix()` transposes after reshaping, while
            # the matrix is written here untransposed -- confirm the two
            # round-trip correctly for cells with off-diagonal terms.
            coords = atoms.get_cell().to_ortho().to_linear().inner.ravel()
            lattice_str = " ".join((f"{c:.8f}" for c in coords))
            params['Lattice'] = lattice_str

            pbc_str = " ".join(str(int(v)) for v in atoms.get_cell().pbc)
            params['pbc'] = pbc_str

        return XYZ(
            atoms.get_atoms('local')._get_frame(),
            params=params
        )

    @staticmethod
    def from_file(file: FileOrPath) -> XYZ:
        """Read an XYZ or extended-XYZ file.

        The first line must hold the atom count and the second line is the
        comment. If the comment parses as extended-XYZ parameters they are
        stored in `params` and used to choose the column schema.

        Raises:
            ValueError: if the atom count on the first line is not an integer.
            IOError: if reading the first line fails.
        """
        logging.info(f"Loading XYZ {file.name if hasattr(file, 'name') else file!r}...")  # type: ignore

        with open_file(file, 'r') as f:
            try:
                # TODO be more gracious about whitespace here
                length = int(f.readline())
            except ValueError:
                raise ValueError("Error parsing XYZ file: Invalid length") from None
            except IOError as e:
                raise IOError(f"Error parsing XYZ file: {e}") from None

            comment = f.readline().rstrip('\n')
            # TODO handle if there's not a gap here

            # A comment that isn't valid extended-XYZ is kept as plain text.
            try:
                params = ExtXYZParser(comment).parse()
            except ValueError:
                params = None

            schema = _get_columns_from_params(params)
            df = parse_whitespace_separated(f, schema, start_line=1)

        # map atomic numbers -> symbols (on columns which are Int8)
        df = df.with_columns(
            get_sym(df.select(polars.col('symbol').cast(polars.Int8, strict=False)).to_series())
            .fill_null(df['symbol']).alias('symbol')
        )
        # ensure all symbols are recognizable (this will raise ValueError if not)
        get_elem(df['symbol'])

        if length < len(df):
            warnings.warn(f"Warning: truncating structure of length {len(df)} "
                          f"to match declared length of {length}")
            df = df[:length]
        elif length > len(df):
            warnings.warn(f"Warning: structure length {len(df)} doesn't match "
                          f"declared length {length}.\nData could be corrupted.")

        # Reuse the parameters parsed above instead of parsing the comment again.
        if params is None:
            return XYZ(df, comment)
        return XYZ(df, comment, params)

    @staticmethod
    def _format_row(values: t.Iterable[t.Any],
                    col_space: t.Sequence[int] = (3, 12, 12, 12)) -> str:
        """Format one atom row (symbol followed by coordinates) as a line of text.

        Each value is left-aligned in its column; floats are written with 8
        decimals and a leading space in place of a '+' sign. A value wider
        than its column consumes the padding, so a single space is inserted
        whenever two adjacent fields would otherwise touch
        (e.g. ``100.5`` followed by ``-2.0``).
        """
        parts: t.List[str] = []
        for val, space in zip(values, col_space):
            text = f"{val:< {space}.8f}" if isinstance(val, float) else f"{val:<{space}}"
            if parts and not parts[-1].endswith(' ') and not text.startswith(' '):
                parts.append(' ')
            parts.append(text)
        return "".join(parts).strip()

    def write(self, file: FileOrPath, fmt: XYZFormat = 'exyz') -> None:
        """Write the structure to `file`.

        With ``fmt == 'exyz'`` and non-empty `params`, the second line holds
        the serialized parameters; otherwise it holds `comment` (or is empty).
        """
        with open_file(file, 'w', newline='\r\n') as f:
            f.write(f"{len(self.atoms)}\n")
            if len(self.params) > 0 and fmt == 'exyz':
                f.write(" ".join(_param_strings(self.params)))
            else:
                f.write(self.comment or "")
            f.write("\n")

            rows = self.atoms.select(('symbol', 'coords')).rows()
            f.writelines(self._format_row(_flatten(row)) + '\n' for row in rows)

    def cell_matrix(self) -> t.Optional[NDArray[numpy.float64]]:
        """Return the 'Lattice' parameter as a 3x3 float matrix.

        The nine values are reshaped to (3, 3) and then transposed. Returns
        ``None`` if the key is absent; warns and returns ``None`` if the value
        is not exactly nine floats.
        """
        if (s := self.params.get('Lattice')) is None:
            return None

        try:
            items = list(map(float, s.split()))
            if len(items) != 9:
                raise ValueError("Invalid length")
            return numpy.array(items, dtype=numpy.float64).reshape((3, 3)).T
        except ValueError:
            warnings.warn(f"Warning: Invalid format for key 'Lattice=\"{s}\"'")
        return None

    def pbc(self) -> t.Optional[NDArray[numpy.bool_]]:
        """Return the 'pbc' parameter as a length-3 boolean array.

        Accepts '0'/'f' and '1'/'t' (case-insensitive) per axis. Returns
        ``None`` if the key is absent; warns and returns ``None`` if the value
        has an unknown token or is not exactly three entries.
        """
        if (s := self.params.get('pbc')) is None:
            return None

        val_map = {'0': False, 'f': False, '1': True, 't': True}
        try:
            items = [val_map[v.lower()] for v in s.split()]
            if len(items) != 3:
                raise ValueError("Invalid length")
            return numpy.array(items, dtype=numpy.bool_)
        except (KeyError, ValueError):
            # KeyError: a token outside val_map; ValueError: wrong entry count.
            warnings.warn(f"Warning: Invalid format for key 'pbc=\"{s}\"'")
        return None